PyTorch: How to Get the Shape of a Tensor as a List of int
As a data scientist working with PyTorch, you’ll often find yourself needing to manipulate tensors. Whether you’re building neural networks or simply preprocessing data, understanding the shape of your tensors is crucial. In this post, we’ll explore how to get the shape of a PyTorch tensor as a list of integers.
What is a Tensor?
Before diving into the specifics of getting the shape of a tensor, let’s first define what a tensor is. In PyTorch, a tensor is a multi-dimensional array containing elements of a single data type. Tensors are the basic building blocks of PyTorch and are used for everything from representing input data to storing model parameters.
Tensors can be created in a variety of ways, including from Python lists, NumPy arrays, or by using PyTorch’s built-in functions. Here’s an example of creating a tensor from a Python list:
import torch
my_list = [1, 2, 3, 4, 5]
my_tensor = torch.tensor(my_list)
In this example, we create a tensor called my_tensor from a Python list called my_list. The resulting tensor is a 1-dimensional tensor with 5 elements.
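The other creation routes mentioned above can be sketched as follows. The variable names are illustrative. torch.from_numpy, torch.zeros and torch.arange are just a few of PyTorch's many options. For each result, the code prints the number of dimensions and the shape as a list.

```python
import numpy as np
import torch

# From a Python list, as in the example above
from_list = torch.tensor([1, 2, 3, 4, 5])

# From a NumPy array; torch.from_numpy shares memory with the source array
from_numpy = torch.from_numpy(np.arange(6).reshape(2, 3))

# From PyTorch's built-in factory functions
zeros = torch.zeros(2, 3, 4)  # 3-D tensor filled with 0.0
ramp = torch.arange(10)       # 1-D tensor with values 0..9

for name, t in [("from_list", from_list), ("from_numpy", from_numpy),
                ("zeros", zeros), ("ramp", ramp)]:
    # dim() is the number of dimensions; list(t.shape) is the shape as ints
    print(name, t.dim(), list(t.shape))
```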
Getting the Shape of a Tensor
Now that we understand what a tensor is, let’s explore how to get its shape. In PyTorch, the shape of a tensor is the number of elements along each of its dimensions. For example, a 2-dimensional tensor with 3 rows and 4 columns has a shape of (3, 4).
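Here is a minimal check on a 3-by-4 tensor that makes the rows-and-columns reading concrete. torch.zeros is used only as a convenient way to produce that shape.

```python
import torch

# A 2-D tensor with 3 rows and 4 columns
t = torch.zeros(3, 4)

rows, cols = t.shape      # unpacks like an ordinary tuple
print(rows, cols)         # 3 4
print(t.shape == (3, 4))  # True: torch.Size compares equal to a plain tuple
print(t.dim())            # 2: one shape entry per dimension
print(t.numel())          # 12: total elements, 3 * 4
```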
To get the shape of a tensor in PyTorch, we can use the size() method. This method returns a torch.Size object, which is a subclass of Python’s built-in tuple type. A torch.Size supports everything a regular tuple does, including indexing, unpacking and len(). It also adds a few tensor-specific conveniences, such as numel(), which returns the total number of elements.
import torch
my_tensor = torch.tensor([[1, 2, 3], [4, 5, 6]])
print(my_tensor.size())  # torch.Size([2, 3])
In this example, we create a 2-dimensional tensor called my_tensor with 2 rows and 3 columns. We then print its shape using the size() method, which outputs torch.Size([2, 3]).
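Because torch.Size subclasses tuple, the usual tuple operations work on the value returned by size(). The sketch below also shows two related calls: size(dim), which returns a single dimension as an int, and torch.Size.numel(). The variable names are illustrative.

```python
import torch

my_tensor = torch.tensor([[1, 2, 3], [4, 5, 6]])
size = my_tensor.size()

print(isinstance(size, tuple))  # True: torch.Size subclasses tuple
print(size[0], size[-1])        # 2 3: ordinary tuple indexing
print(len(size))                # 2: number of dimensions
print(my_tensor.size(1))        # 3: size(dim) returns one dimension as an int
print(size.numel())             # 6: product of all dimensions
```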
While the size() method returns a torch.Size object, we often need the shape of a tensor as a plain list of integers. To convert the torch.Size object to a list, we can pass it to Python’s built-in list() function.
import torch
my_tensor = torch.tensor([[1, 2, 3], [4, 5, 6]])
shape = my_tensor.size()
shape_list = list(shape)
print(shape_list)  # [2, 3]
In this example, we again create a 2-dimensional tensor called my_tensor with 2 rows and 3 columns. We then get its size using the size() method and convert it to a list of integers with the built-in list() function. The resulting output is [2, 3].
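The list form has several practical uses, shown in the short sketch below with illustrative variable names:

- Its entries are plain Python ints, so it serializes to JSON directly.
- It is mutable, so you can edit a dimension before reshaping.
- A 0-dimensional (scalar) tensor yields an empty list.

Note that my_tensor.tolist() does something different: it converts the tensor's values, not its shape.

```python
import json

import torch

my_tensor = torch.tensor([[1, 2, 3], [4, 5, 6]])
shape_list = list(my_tensor.size())

# Each entry is a plain Python int, so the list can go straight into JSON
print(all(type(d) is int for d in shape_list))  # True
print(json.dumps({"shape": shape_list}))  # {"shape": [2, 3]}

# Unlike torch.Size, a list is mutable: set the first dimension to -1 (inferred)
shape_list[0] = -1
reshaped = my_tensor.reshape(shape_list)
print(list(reshaped.shape))  # [2, 3]

# A 0-dimensional (scalar) tensor has an empty shape
scalar = torch.tensor(3.14)
print(list(scalar.shape))  # []

# .tolist() converts the values, not the shape
print(my_tensor.tolist())  # [[1, 2, 3], [4, 5, 6]]
```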
When working with PyTorch tensors, understanding their dimensions is crucial. The size() method is the common way to retrieve a tensor's size, but there is also the .shape attribute. It returns the same torch.Size object more concisely, with no method call.
The following example demonstrates the use of .shape to obtain the dimensions of a tensor:
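import torch
my_tensor = torch.tensor([[1, 2, 3], [4, 5, 6]])
shape_alternative = my_tensor.shape  # attribute access: no parentheses
shape_list_alternative = list(shape_alternative)
print(shape_list_alternative)  # [2, 3]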
In this case, instead of calling the size() method, we access the tensor’s shape directly through the .shape attribute. We then convert the shape to a list of integers exactly as before, and the output is again [2, 3]. Since .shape is an alias for size(), both approaches give identical results. .shape is simply shorter and mirrors NumPy’s ndarray.shape.
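The sketch below compares the two approaches directly to confirm they are interchangeable. One practical difference is that .shape takes no arguments. To read a single dimension you index it, as in shape[1], whereas size(1) takes the dimension as an argument.

```python
import torch

my_tensor = torch.tensor([[1, 2, 3], [4, 5, 6]])

# .shape (attribute) and .size() (method) return equal torch.Size objects
print(my_tensor.shape == my_tensor.size())  # True
print(type(my_tensor.shape) is torch.Size)  # True

# .shape takes no arguments, so index it to read a single dimension
print(my_tensor.shape[1], my_tensor.size(1))  # 3 3
```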
Conclusion
In this post, we’ve explored how to get the shape of a PyTorch tensor as a list of integers. We’ve seen how to use the size() method, or the equivalent .shape attribute, to get the size of a tensor as a torch.Size object. We’ve also seen how to convert that object to a list of integers using the built-in list() function.
Understanding the shape of your tensors is essential when working with PyTorch, and knowing how to get the shape of a tensor is a fundamental skill for any data scientist. By following the techniques outlined in this post, you’ll be well-equipped to manipulate tensors in your PyTorch projects.