How to use PyTorch DataLoader: Custom Datasets, Transformations, and Efficient Techniques
Data is the backbone of any deep-learning model, and efficiently managing and processing that data is essential for training. Imagine we’re training a deep learning model on thousands, or even millions, of images: how do we ensure that the model gets the correct data in the right format and at the right speed? The right format ensures the data is compatible with the model, while the right speed prevents the GPU from idling, maximizing resource efficiency and reducing training time. This is where PyTorch’s DataLoader comes into play. Let’s see what the PyTorch DataLoader is, how we can work with it, how to create a custom dataset, and what data augmentation methods PyTorch offers.
Introduction to PyTorch DataLoader
The DataLoader in PyTorch is a powerful built-in class designed to handle loading and managing datasets. It acts as an interface between a dataset and the model, helping us feed data to the model in batches during the training or testing phases. The DataLoader abstracts away much of the complexity associated with handling large datasets, offering built-in batching, shuffling, and parallel data-loading features, which we’ll explore in the next section.
How PyTorch DataLoader Works
The PyTorch DataLoader works by wrapping around a dataset, whether it’s a built-in PyTorch dataset (like MNIST or CIFAR-10) or a custom one. It allows us to iterate through the dataset in a manner that’s both memory- and time-efficient. The key functions of the DataLoader include:
- Batching: Rather than loading one sample at a time, DataLoader allows us to load data in batches, taking advantage of the GPU’s parallel computing capabilities and speeding up the training process.
- Shuffling: An epoch refers to one complete cycle through the entire training dataset by the learning algorithm. Shuffling the data at the beginning of each epoch helps prevent the model from memorizing the data order, which helps the model generalize well to unseen data.
- Parallel loading: Using multiple workers (subprocesses), DataLoader can load data in parallel. This ensures the data loading process doesn’t become a bottleneck when working with large datasets.
- Collate function: DataLoader also allows us to define a custom function that combines individual data points into batches (see the sketch after this list).
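To make these features concrete, here is a minimal sketch (not part of the original example) that wraps a small in-memory TensorDataset in a DataLoader with batching, shuffling, two workers, and a custom collate function; the tensor shapes and batch size are arbitrary illustrations:

import torch
from torch.utils.data import DataLoader, TensorDataset

# A tiny in-memory dataset: 8 samples with 3 features each, plus binary labels
features = torch.randn(8, 3)
labels = torch.randint(0, 2, (8,))
dataset = TensorDataset(features, labels)

def pair_collate(batch):
    # `batch` is a list of (features, label) tuples; stack them into batch tensors
    xs = torch.stack([x for x, _ in batch])
    ys = torch.stack([y for _, y in batch])
    return xs, ys

loader = DataLoader(
    dataset,
    batch_size=4,             # load four samples per batch
    shuffle=True,             # reshuffle at the start of every epoch
    num_workers=2,            # two subprocesses load data in parallel
    collate_fn=pair_collate,  # custom function that combines samples into a batch
)

for xs, ys in loader:
    print(xs.shape, ys.shape)  # torch.Size([4, 3]) torch.Size([4])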
The DataLoader relies on a dataset object to know how many records exist and how to fetch each record by index. We often train models on our own data, so we need to create a custom dataset object.
Creating Custom Datasets in PyTorch
In PyTorch, the Dataset class is the primary tool for handling data. It acts as an interface that allows us to define how our data is accessed from files, APIs, or even generated from scratch. The Dataset class is part of the torch.utils.data module and helps prepare data for training by abstracting the complexities of data loading.
The two primary methods you need to implement when creating a custom dataset are:
- __len__(): Defines the total number of samples in our dataset.
- __getitem__(): Retrieves a specific data sample by index.
We’ll see how to override these two methods and create a simple custom dataset from a CSV file.
Step-by-Step Guide on Creating a Simple Custom Dataset
We’ll create a custom dataset from a CSV file that contains three columns: two feature columns (feature1, feature2) and one label column (label).
The data in the CSV file is as follows:
feature1,feature2,label
1.0,2.0,0
2.0,3.0,1
3.0,4.0,0
4.0,5.0,1
This dataset has two features and one label for each row. We’ll build a custom dataset class to load this data.
Step 1: Import Required Libraries
First, import the essential libraries for managing datasets in PyTorch: torch for tensors, Dataset and DataLoader for custom datasets and batching, and pandas for data manipulation. Ensure these libraries are installed in your Python environment before proceeding.
import torch
from torch.utils.data import Dataset, DataLoader
import pandas as pd
Step 2: Define a Simple Custom Dataset Class
We’ll create a custom class that loads the CSV data using pandas and then returns the features and labels as PyTorch tensors.
class SimpleCSVLoader(Dataset):
    def __init__(self, csv_file):
        # Load the CSV file using pandas
        self.data = pd.read_csv(csv_file)

    def __len__(self):
        # Return the total number of samples in the dataset
        return len(self.data)

    def __getitem__(self, idx):
        # Extract the row at the given index
        row = self.data.iloc[idx]
        # Get the features (first two columns) and label (last column)
        features = row[['feature1', 'feature2']].values.astype(float)
        label = row['label']
        # Convert to PyTorch tensors
        features = torch.tensor(features, dtype=torch.float32)
        label = torch.tensor(label, dtype=torch.long)
        return features, label
In this class:
- __init__() loads the CSV file into memory.
- __len__() returns the number of rows in the dataset.
- __getitem__() takes an index (idx), retrieves the features and label from the CSV, and converts them into tensors.
Step 3: Create the Custom Dataset
Let’s create an instance of the SimpleCSVLoader class using our CSV file. This class is designed to read and manage data stored in a CSV file and convert it into PyTorch tensors.
# Create an instance of the custom dataset
dataset = SimpleCSVLoader(csv_file='sample_data.csv')
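Before adding a DataLoader, we can sanity-check the dataset directly, since Dataset instances support len() and indexing. Assuming sample_data.csv holds the four rows shown earlier, this optional check would print the first sample:

print(len(dataset))          # 4
features, label = dataset[0]
print(features)              # tensor([1., 2.])
print(label)                 # tensor(0)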
Step 4: Use DataLoader for Batching and Shuffling
Next, we pass our custom dataset to the DataLoader for batching and shuffling.
# Create a DataLoader for batching
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)
Here, batch_size=2 means the DataLoader will load two samples at a time, and shuffle=True ensures the data is shuffled before each epoch.
Step 5: Iterating Through DataLoader
We can now iterate through the DataLoader to retrieve batches of data during training.
for batch_idx, (features, labels) in enumerate(dataloader):
    print(f"Batch {batch_idx + 1}")
    print(f"Features:\n{features}")
    print(f"Labels:\n{labels}\n")
This loop iterates through the dataloader, retrieving one batch at a time:
- batch_idx is the batch index.
- (features, labels) is a tuple where features contains the input data (e.g., feature1 and feature2), and labels contains the target labels for that batch.
The output would look something like this (the order might differ due to shuffling):
Batch 1
Features:
tensor([[2., 3.],
        [1., 2.]])
Labels:
tensor([1, 0])

Batch 2
Features:
tensor([[4., 5.],
        [3., 4.]])
Labels:
tensor([1, 0])
The custom dataset loads data from a CSV file and returns the features and label for each sample. The DataLoader batches and shuffles the data, making it ready for use in model training. This basic structure is enough to get started with custom datasets in PyTorch, and we can extend it as needed for more complex datasets.
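As one example of extending it, the same dataset can be split into training and validation subsets with torch.utils.data.random_split and wrapped in separate loaders; a minimal sketch (the 3/1 split simply matches our four-row sample file):

from torch.utils.data import DataLoader, random_split

# Split the 4-sample dataset into 3 training samples and 1 validation sample
train_set, val_set = random_split(dataset, [3, 1])

train_loader = DataLoader(train_set, batch_size=2, shuffle=True)
val_loader = DataLoader(val_set, batch_size=1, shuffle=False)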
Data Transformations for PyTorch Models
In machine learning, particularly in deep learning, data transformation is an essential preprocessing step that prepares raw data for training models. Transformations help improve the quality of input data and make it more suitable for model training.
Transformations like resizing, converting images to tensors, or normalizing pixel values are common for image data. These transformations help the model to see data in a consistent, well-scaled format.
PyTorch provides the torchvision.transforms module, which is specifically designed for image data transformation and augmentation. It includes a set of tools to perform operations like resizing, cropping, and normalization.
Here are three of the most common transformations provided by torchvision.transforms:
Resizing
Resizing is often necessary because most neural networks expect input images of a fixed size. torchvision.transforms.Resize() allows us to change the dimensions of an image to the required size.
Example:
from torchvision import transforms

transform = transforms.Resize((128, 128))  # Resizes images to 128x128 pixels
This ensures that all images in the dataset are the same size, which is necessary for batch processing.
Normalization
Normalization helps adjust the pixel values of an image so that they fall within a specific range (usually [-1, 1] or [0, 1]). It standardizes the data, making it easier for the model to learn patterns.
Example:
transform = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
This will normalize the pixel values for each channel (R, G, and B) by subtracting the mean and dividing by the standard deviation. With a mean and standard deviation of 0.5, pixel values that start in the [0, 1] range are shifted and scaled into the [-1, 1] range, so the data is centred around zero.
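In practice, the mean and std passed to Normalize are often estimated from the training images themselves. Here is a minimal sketch of that estimation, assuming a hypothetical image_dataset that yields (image, label) pairs where each image is already a (3, H, W) tensor produced by ToTensor():

import torch
from torch.utils.data import DataLoader
from torchvision import transforms

loader = DataLoader(image_dataset, batch_size=64)

channel_sum = torch.zeros(3)
channel_sq_sum = torch.zeros(3)
num_pixels = 0
for images, _ in loader:
    # Sum over batch, height, and width, keeping the channel dimension
    channel_sum += images.sum(dim=[0, 2, 3])
    channel_sq_sum += (images ** 2).sum(dim=[0, 2, 3])
    num_pixels += images.shape[0] * images.shape[2] * images.shape[3]

mean = channel_sum / num_pixels
std = (channel_sq_sum / num_pixels - mean ** 2).sqrt()
transform = transforms.Normalize(mean=mean.tolist(), std=std.tolist())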
Tensor Conversion
In PyTorch, models operate on tensors, so images (or any data) need to be converted into tensors before they can be fed into a model. transforms.ToTensor() converts an image into a PyTorch tensor and scales the pixel values from the [0, 255] range to [0, 1] by dividing them by 255.
Example:
transform = transforms.ToTensor()
This is typically one of the first transformations applied, as it allows the image data to be compatible with PyTorch’s models.
Applying Transformations in the Dataset Class
Now, let’s apply transformations (resizing, normalization, and tensor conversion) to the custom dataset we created earlier using PyTorch’s torchvision.transforms.
First, we’ll import the necessary classes into our code:
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import pandas as pd
- Dataset and DataLoader: Utilities from PyTorch for loading and managing datasets. Dataset defines how to access our data, while DataLoader handles batching, shuffling, and loading data efficiently.
- torchvision.transforms: Provides common data transformation functions, specifically for image preprocessing and augmentation.
- pandas: Used to load and manipulate data from CSV files in this example.
Custom Dataset Class with Transformations
When working with custom datasets in PyTorch, it’s often necessary to define our own dataset class. This allows us to customize how data is loaded, preprocessed, and accessed. PyTorch provides the Dataset class as a base, which we can use to define our custom dataset.
The example below defines a class SimpleCSVLoader, which is tailored for loading data from a CSV file. This class also incorporates an optional transformation pipeline to preprocess the data, making it flexible for various use cases like data augmentation or normalization.
class SimpleCSVLoader(Dataset):
    def __init__(self, csv_file, transform=None):
        self.data = pd.read_csv(csv_file)  # Load CSV data into a pandas DataFrame
        self.transform = transform  # Store any transformations passed to the class
- The code above defines a custom dataset class, which inherits from PyTorch’s Dataset.
- __init__(self, csv_file, transform=None): The constructor method takes two arguments:
  - csv_file: The path to the CSV file, which is loaded into a pandas DataFrame (self.data).
  - transform: This is optional and stores the transformation pipeline (like resizing, normalization, etc.). If no transformations are provided, transform is set to None.
This sets up the class to load data and optionally apply transformations.
Dataset Length Method
In PyTorch, the __len__ method is required for any custom dataset class. It tells PyTorch the total number of samples in the dataset, which is important for batching and iterating through the data. This method allows the DataLoader to determine how many batches can be created from the dataset.
def __len__(self):
    return len(self.data)  # Return the total number of rows (data points) in the dataset
This method returns the number of rows in self.data. It is required because PyTorch needs to know how many data points to expect when iterating through the dataset.
Getting a Data Sample
The __getitem__ method defines how to retrieve a single data sample (or row) from the dataset when accessed using an index. This method is invoked during data loading to supply individual samples to the model, making it crucial for efficient and accurate data processing.
Here’s an implementation of the __getitem__ method:
def __getitem__(self, idx):
    row = self.data.iloc[idx]  # Select a specific row from the dataset using its index
    features = row[['feature1', 'feature2']].values.astype(float)  # Extract the first two columns as features
    label = row['label']  # Extract the label from the last column
    features = torch.tensor(features, dtype=torch.float32)  # Convert features to a PyTorch tensor
    label = torch.tensor(label, dtype=torch.long)  # Convert label to a tensor (as a long integer for classification)
    if self.transform:
        features = self.transform(features)  # Apply the transformation if it's defined
    return features, label  # Return the features and label
The code above defines how individual samples are processed and returned. Here’s how the method works:
- __getitem__(self, idx): Retrieves a single data point (row) from the dataset based on the given index (idx).
- row = self.data.iloc[idx]: Extracts the row with the specified index from the DataFrame.
- features: Extracts the first two columns (feature1 and feature2) as input features, converts them to a NumPy array, and casts them to a float type.
- label: Extracts the value from the label column (the target variable).
- Tensor conversion: Both the features and the label are converted into PyTorch tensors so a PyTorch model can process them.
- Transformation application: If transformations were provided, they are applied to the features using self.transform(features).
This ensures that each data sample is converted into tensors and transformed before being returned to the DataLoader or any iterator accessing the dataset.
Defining Transformations
Transformations can be composed into a pipeline using transforms.Compose, and they are applied sequentially to each sample.
transform = transforms.Compose([
    transforms.Resize((128, 128)),  # Resize feature images to 128x128 pixels
    transforms.ToTensor(),  # Convert the image to a PyTorch tensor
    transforms.Normalize(mean=[0.5, 0.5], std=[0.5, 0.5])  # Normalize with mean and std dev
])
We use transforms.Compose([...]) to chain the following transformations:
- transforms.Resize((128, 128)): Resizes the input (typically an image) to a fixed size of 128x128 pixels.
- transforms.ToTensor(): Converts the image or data into a PyTorch tensor (required for PyTorch models).
- transforms.Normalize(mean=[0.5, 0.5], std=[0.5, 0.5]): Normalizes the pixel values to have a mean of 0.5 and a standard deviation of 0.5 for each channel, standardizing the input data.

Note that these particular transforms are designed for image data: Resize and ToTensor expect a PIL image, and Normalize expects a channel-first image tensor. They are shown here to illustrate how a pipeline plugs into the dataset; applied directly to the two numeric features from our CSV they would raise an error, so for tabular data we would pass a custom transform instead, as sketched below.
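For the numeric features in this CSV example, a simple custom callable can stand in for the image transforms, since transforms.Compose accepts any callables. The mean and std values below are illustrative numbers roughly matching the sample CSV, not statistics computed from real data:

import torch
from torchvision import transforms

class NormalizeFeatures:
    """Standardize a 1-D feature tensor using fixed per-feature statistics."""
    def __init__(self, mean, std):
        self.mean = torch.tensor(mean)
        self.std = torch.tensor(std)

    def __call__(self, features):
        return (features - self.mean) / self.std

# Compose chains any callables, so custom transforms work alongside built-in ones
tabular_transform = transforms.Compose([
    NormalizeFeatures(mean=[2.5, 3.5], std=[1.12, 1.12]),
])

When working with this CSV data, tabular_transform could be passed as the transform argument in the next step in place of the image pipeline.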
Creating an Instance of the Dataset
Now, we create an instance of our custom dataset class by providing the path to the CSV file and the transformation pipeline. This means that each time we access a data point, the defined transformations will be applied automatically.
dataset = SimpleCSVLoader(csv_file='simple_data.csv', transform=transform)
DataLoader for Batching and Shuffling
The DataLoader will handle the batching and shuffling of data, which is crucial for efficient training.
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)
We pass the dataset instance to PyTorch’s DataLoader:
- batch_size=2: Each batch will consist of 2 samples (features and labels).
- shuffle=True: The data will be shuffled before each epoch, ensuring that the model doesn’t see the data in the same order every time.
Iterating Through the DataLoader
We’ll create a loop that allows us to process each batch of data during model training; for each batch, the features and labels are printed:
for batch_idx, (features, labels) in enumerate(dataloader):
    print(f"Batch {batch_idx + 1}")
    print(f"Features: {features}")
    print(f"Labels: {labels}\n")
The code iterates through each batch of data from the DataLoader, printing the batch index, features, and labels. This helps inspect the data being passed to the model during training or evaluation. By understanding how data flows through the DataLoader, we can verify that our preprocessing steps are working as intended.
Building on this foundation, we will learn PyTorch Data Augmentation Techniques, which can significantly enhance model robustness by synthetically increasing dataset variability.
PyTorch Data Augmentation Methods
One of the major challenges when working with datasets, especially in computer vision, is preventing the model from overfitting. Overfitting happens when a model becomes too good at memorizing the exact examples it trained on. As a result, while the model performs well on familiar training data, it fails when given new, slightly different examples.
Data Augmentation addresses this issue by artificially increasing the diversity of the training data without actually collecting more samples. It applies various methods, such as flipping, rotating, cropping, and altering color properties, to create multiple variations of existing images.
These augmentations make the model more robust by teaching it to recognize patterns regardless of variations in orientation, position, or lighting conditions.
PyTorch provides several data augmentation techniques through its torchvision.transforms module. Below are some of the most popular augmentation techniques:
Flipping
Flipping horizontally or vertically is one of the simplest and most effective augmentation techniques. It helps the model learn that objects can appear in different orientations.
- Horizontal Flip: Mirrors the image along the vertical axis.
- Vertical Flip: Mirrors the image along the horizontal axis.
In the following example, the image has a 50% chance of being flipped horizontally or vertically, adding variability to the dataset. A probability of p=0.5 (50%) is typically used when we want the transformation to occur randomly and equally often.
import torchvision.transforms as transforms

# Horizontal Flip with 50% probability
transform = transforms.RandomHorizontalFlip(p=0.5)

# Vertical Flip with 50% probability
transform = transforms.RandomVerticalFlip(p=0.5)
Cropping
Random cropping allows the model to focus on different parts of the image. This is particularly useful in scenarios where objects may appear in various locations within the image. By randomly cropping the image, we effectively teach the model that important features may not always be centred.
# Randomly crop a 100x100 patch from the image
transform = transforms.RandomCrop(size=(100, 100))
The above line of code randomly extracts a 100x100 patch from the original image, forcing the model to adapt to different viewpoints and spatial arrangements.
Rotation
Rotation augmentation applies random rotations to images. It allows the model to recognize objects in different orientations, which is useful for tasks like object detection or classification, where objects may not always be upright.
The following line randomly rotates the image by any angle between -30 and 30 degrees, encouraging the model to learn rotational invariance.
# Rotate the image by a random angle between -30 and 30 degrees
transform = transforms.RandomRotation(degrees=(-30, 30))
Color Jitter
Color jitter modifies the brightness, contrast, saturation, and hue of the image. This helps the model generalize better to different lighting conditions, shadows, or color variations in real-world data.
# Adjust brightness, contrast, saturation, and hue randomly
transform = transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.1)
In this case, the image’s brightness, contrast, saturation, and hue are randomly adjusted within the specified ranges, making the model more resilient to color changes in the environment.
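In practice, these augmentations are usually chained together with ToTensor and Normalize in a single Compose pipeline that is passed to the dataset’s transform argument. The following is an illustrative sketch; it assumes RGB input images larger than 100x100 pixels, and the specific augmentations and normalization values are example choices:

from torchvision import transforms

train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=(-30, 30)),
    transforms.RandomCrop(size=(100, 100)),   # assumes images are at least 100x100
    transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])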
Efficient Data Loading Techniques in PyTorch
A common bottleneck when working with large datasets arises when the data loading process is slower than the model’s ability to consume the data. This causes the GPU to remain idle while waiting for data to be loaded and preprocessed. This inefficiency can dramatically increase training time and waste computational resources.
Efficient data loading in PyTorch involves optimizing how data is fetched, processed, and sent to the GPU. In this section, we’ll cover strategies like utilizing multiple workers with num_workers, prefetching and caching data, and using lazy loading techniques to ensure that our data pipeline keeps up with model training.
Utilizing Multiple Workers with num_workers
One of the most effective ways to speed up data loading is by parallelizing the process. PyTorch allows us to specify the number of worker processes that load data in parallel using the num_workers parameter in the DataLoader. Each worker is responsible for retrieving and processing a portion of the data.
By default (num_workers=0), data is loaded sequentially in the main process. Increasing the number of workers enables data to be fetched and processed in parallel, significantly speeding up the loading process.
from torch.utils.data import DataLoader

# Assuming the dataset is already defined
dataloader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=4)
In this example, num_workers=4 will cause four worker processes to load data in parallel, reducing the GPU’s waiting time.
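Two related DataLoader options are often combined with multiple workers: pin_memory places fetched batches in page-locked memory, which can speed up transfers to the GPU, and persistent_workers keeps the worker processes alive between epochs instead of recreating them. A minimal sketch (the batch size and worker count are illustrative):

from torch.utils.data import DataLoader

dataloader = DataLoader(
    dataset,
    batch_size=64,
    shuffle=True,
    num_workers=4,
    pin_memory=True,          # speeds up host-to-GPU copies
    persistent_workers=True,  # requires num_workers > 0
)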
Prefetching and Caching Data
Prefetching is another way to ensure that the model always has data to process. The concept involves loading the next batch of data while the model is training on the current batch. This way, the CPU actively fetches data while the GPU processes the existing batch.
When worker processes are used, PyTorch’s DataLoader prefetches batches automatically: each worker loads batches ahead of time, and the prefetch_factor parameter controls how many batches each worker keeps ready.
dataloader = DataLoader(dataset, batch_size=64, num_workers=4, prefetch_factor=2)
Here, prefetch_factor=2 means each worker preloads two batches of data, so when the model finishes processing one batch, the next is ready for immediate use.
Handling Large Datasets with Lazy Loading
When working with large datasets that cannot fit into memory, lazy loading (also known as on-the-fly data loading) becomes important. Instead of loading the entire dataset into memory, PyTorch only loads the data needed for each batch during training. This is useful for high-resolution image datasets or large video datasets.
In PyTorch, lazy loading is inherently supported by the Dataset class, where data is only accessed during each call to __getitem__().
In the following example, images are loaded only when needed, ensuring that memory is conserved while handling large datasets:
from PIL import Image

class LazyDataset(Dataset):
    def __init__(self, image_paths, transform=None):
        self.image_paths = image_paths
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Load the image from disk at runtime
        image = Image.open(self.image_paths[idx])
        if self.transform:
            image = self.transform(image)
        return image
In this code, the __getitem__() method loads images from disk at runtime as needed. The __len__() method provides the dataset size, and only the necessary images are loaded into memory for each batch, conserving memory. The transform is applied to the images if defined, allowing for on-the-fly data preprocessing.
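A minimal usage sketch follows; the directory path, file pattern, and transforms are illustrative assumptions, and it presumes the folder contains same-mode (e.g., RGB) images so that batches can be stacked:

import glob
from torch.utils.data import DataLoader
from torchvision import transforms

# Only file paths are held in memory up front; images are read on demand
image_paths = glob.glob('images/*.jpg')

transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
])

lazy_dataset = LazyDataset(image_paths, transform=transform)
lazy_loader = DataLoader(lazy_dataset, batch_size=32, shuffle=True, num_workers=4)

for batch in lazy_loader:
    # Each batch is loaded from disk only when this iteration reaches it
    print(batch.shape)
    break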
By using multiple workers (num_workers), prefetching data, and employing lazy loading, we can ensure that our model is continuously supplied with data without unnecessary delays.
Wrapping up
Mastering these key concepts (efficient data loading, custom datasets, transformations, and augmentations) will help you handle complex datasets and optimize model training. These tools are important for building scalable, high-performance deep learning models. We learned:
- DataLoader is a PyTorch class that efficiently manages data loading through batching, shuffling, and parallel processing.
- Custom datasets require implementing two key methods: __len__() for the total number of samples and __getitem__() for accessing individual samples.
- Data transformations like resizing, normalization, and tensor conversion prepare raw data for model training.
- Data augmentation techniques (flipping, rotation, cropping) help prevent overfitting by increasing dataset variability.
- Multiple workers (num_workers) enable parallel data loading to prevent GPU idle time.
- Prefetching loads the next batch while the current batch is being processed, improving pipeline efficiency.
- Lazy loading handles large datasets by loading data only when needed rather than all at once.
To learn more about PyTorch and how to build and train neural networks, you can enroll in Codecademy’s free Intro to PyTorch and Neural Networks course.