PyTorch: How to Get the Shape of a Tensor as a List of int

As a data scientist working with PyTorch, you’ll often find yourself needing to manipulate tensors. Whether you’re building neural networks or simply preprocessing data, understanding the shape of your tensors is crucial. In this post, we’ll explore how to get the shape of a PyTorch tensor as a list of integers.

Table of Contents

  1. What is a Tensor?
  2. Getting the Shape of a Tensor
  3. Conclusion

What is a Tensor?

Before diving into the specifics of getting the shape of a tensor, let’s first define what a tensor is. In PyTorch, a tensor is a multi-dimensional array containing elements of a single data type. Tensors are the basic building blocks of PyTorch and are used for everything from representing input data to storing model parameters.

Tensors can be created in a variety of ways, including from Python lists, NumPy arrays, or by using PyTorch’s built-in functions. Here’s an example of creating a tensor from a Python list:

import torch

my_list = [1, 2, 3, 4, 5]
my_tensor = torch.tensor(my_list)

In this example, we create a tensor called my_tensor from a Python list called my_list. The resulting tensor is a 1-dimensional tensor with 5 elements.
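Besides torch.tensor(), PyTorch’s built-in factory functions can create a tensor of a given shape directly. The following is a minimal sketch using torch.zeros(), torch.ones(), and torch.arange(). The NumPy route is left out here so that the snippet runs without NumPy installed:

```python
import torch

# Factory functions take the desired shape as arguments
zeros = torch.zeros(2, 3)   # 2 rows, 3 columns, filled with 0.0
ones = torch.ones(4)        # 1-dimensional, 4 elements, filled with 1.0
counting = torch.arange(6)  # the integers 0 through 5

# reshape() rearranges the same 6 elements into 2 rows of 3
grid = counting.reshape(2, 3)

print(zeros.shape)  # torch.Size([2, 3])
print(ones.shape)   # torch.Size([4])
print(grid)
```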

Getting the Shape of a Tensor

Now that we understand what a tensor is, let’s explore how to get the shape of a tensor. In PyTorch, the shape of a tensor refers to the number of elements along each dimension of the tensor. For example, a 2-dimensional tensor with 3 rows and 4 columns has a shape of (3, 4).
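The sketch below makes the (3, 4) example concrete. It builds a tensor of that shape and inspects it. Here dim(), also available as the ndim attribute, reports the number of dimensions, which is always the length of the shape:

```python
import torch

# A 2-dimensional tensor with 3 rows and 4 columns
matrix = torch.zeros(3, 4)

print(matrix.shape)  # torch.Size([3, 4])
print(matrix.dim())  # 2, the number of dimensions
print(matrix.ndim)   # 2, attribute alias for dim()

# The number of dimensions equals the length of the shape
assert matrix.dim() == len(matrix.shape)
```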

To get the shape of a tensor in PyTorch, we can use the size() method. This method returns a torch.Size object, which is a subclass of Python’s built-in tuple type. A torch.Size holds the same information as a regular tuple and supports indexing and unpacking. It also adds a few tensor-specific methods, such as numel(), which returns the total number of elements.

import torch

my_tensor = torch.tensor([[1, 2, 3], [4, 5, 6]])
print(my_tensor.size())

In this example, we create a 2-dimensional tensor called my_tensor with 2 rows and 3 columns. We then print its shape using the size() method, which outputs torch.Size([2, 3]).
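Because torch.Size is a tuple subclass, it can be indexed and unpacked like any other tuple. The sketch below also does two more things. It passes a dimension index to size(), which returns that single dimension as an int. It also calls numel(), one of the extra torch.Size methods:

```python
import torch

my_tensor = torch.tensor([[1, 2, 3], [4, 5, 6]])
size = my_tensor.size()

# torch.Size is a tuple subclass, so tuple operations work on it
print(isinstance(size, tuple))  # True
print(size[0], size[-1])        # 2 3
rows, cols = size               # unpacking works like a tuple

# Passing a dimension index returns that single dimension as an int
print(my_tensor.size(1))        # 3

# numel() returns the total number of elements (2 * 3)
print(size.numel())             # 6
```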

While the size() method returns a torch.Size object, we often need the shape of a tensor as a plain list of integers. To convert the torch.Size object to a list, we can pass it to Python’s built-in list() function.

import torch

my_tensor = torch.tensor([[1, 2, 3], [4, 5, 6]])
shape = my_tensor.size()
shape_list = list(shape)
print(shape_list)

In this example, we create a 2-dimensional tensor called my_tensor with 2 rows and 3 columns. We then get the size of the tensor using the size() method and convert it to a list of integers with the built-in list() function. The resulting output is [2, 3].
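A common reason to prefer a list over a torch.Size is mutability. A tuple cannot be changed, but a list can be edited to build a new shape. The following is a minimal sketch of that pattern. The name new_shape is only illustrative. The sketch keeps every dimension except the last:

```python
import torch

my_tensor = torch.tensor([[1, 2, 3], [4, 5, 6]])
shape_list = list(my_tensor.size())

# Each entry is a plain Python int, not a tensor
print(all(isinstance(d, int) for d in shape_list))  # True

# A list can be modified, unlike the torch.Size tuple
new_shape = shape_list.copy()
new_shape[-1] = 1                # keep the rows, collapse the columns to 1
column = torch.zeros(new_shape)  # factory functions also accept a list
print(column.shape)              # torch.Size([2, 1])
```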

When working with PyTorch tensors, understanding their dimensions is crucial. The size() method is commonly used to retrieve the size of a tensor, but the .shape attribute offers a more concise and direct alternative.

The following example demonstrates the use of .shape to obtain the dimensions of a tensor:

import torch

my_tensor = torch.tensor([[1, 2, 3], [4, 5, 6]])
shape_alternative = my_tensor.shape  # attribute access, no call needed
shape_list_alternative = list(shape_alternative)
print(shape_list_alternative)

In this case, instead of calling the size() method, we read the tensor’s .shape attribute directly. It returns the same torch.Size object as size(). Converting it with list() therefore gives the same list of integers, [2, 3].
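Because .shape and size() return equal torch.Size objects, the two approaches are interchangeable. The sketch below wraps the conversion in a small helper. The name shape_as_list is hypothetical and is not a PyTorch function. The sketch also shows that a 0-dimensional tensor, such as torch.tensor(5), has an empty shape:

```python
import torch

def shape_as_list(t: torch.Tensor) -> list[int]:
    """Return the shape of a tensor as a list of Python ints (hypothetical helper)."""
    return list(t.shape)

my_tensor = torch.tensor([[1, 2, 3], [4, 5, 6]])

# .shape and size() return equal torch.Size objects
print(my_tensor.shape == my_tensor.size())  # True
print(shape_as_list(my_tensor))             # [2, 3]

# A 0-dimensional (scalar) tensor has an empty shape
scalar = torch.tensor(5)
print(shape_as_list(scalar))                # []
```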

Conclusion

In this post, we’ve explored how to get the shape of a PyTorch tensor as a list of integers. We’ve seen how to use the size() method or the .shape attribute to get the shape of a tensor as a torch.Size object. We’ve also seen how to convert that object to a list of integers using the built-in list() function.

Understanding the shape of your tensors is essential when working with PyTorch, and knowing how to get the shape of a tensor is a fundamental skill for any data scientist. By following the techniques outlined in this post, you’ll be well-equipped to manipulate tensors in your PyTorch projects.


About Saturn Cloud

Saturn Cloud is your all-in-one solution for data science & ML development, deployment, and data pipelines in the cloud. Spin up a notebook with 4TB of RAM, add a GPU, connect to a distributed cluster of workers, and more. Request a demo today to learn more.