
Using PyTorch argmax function with examples

PyTorch offers many capabilities that help developers build deep learning algorithms effectively. One of the functions it provides is argmax(). With argmax(), we obtain the index of the maximum element of a tensor; the maximum value itself can then be read from the tensor at that index.

PyTorch also has functions that return more than the single highest element of an input tensor, for example the second-highest items. However, when all we need is the position of the highest element, argmax() is the most direct choice.
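If you need more than the single largest element, torch.topk() is one such function. The minimal sketch below contrasts it with argmax() on a small tensor whose values are chosen by hand for illustration; it is not taken from the original article.

```python
import torch

# Hand-picked example values
x = torch.tensor([3.0, 7.0, 1.0, 5.0])

# argmax() gives only the position of the single largest element
top_index = torch.argmax(x)

# topk(k=2) returns the two largest values and their positions, sorted in descending order
top2 = torch.topk(x, k=2)

print("argmax index:", top_index.item())               # 1
print("top-2 values:", top2.values.tolist())           # [7.0, 5.0]
print("top-2 indices:", top2.indices.tolist())         # [1, 3]
print("second-highest value:", top2.values[1].item())  # 5.0
```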

The torch.argmax() function returns the indices of the elements with the highest values in an input tensor. Only the indices are returned, not the element values. If the input tensor contains more than one maximal value, the index of the first maximal element is returned. torch.argmax() can find the index of the maximum over the whole tensor (in which case the tensor is treated as flattened) or the indices of the maxima along a given dimension.
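The short sketch below illustrates these points on hand-picked values (an illustrative assumption, not data from the article): ties resolve to the first maximum, the result is an integer index rather than the value, and a 2D tensor without dim is treated as flattened.

```python
import torch

# The value 9 appears twice, at positions 1 and 3
t = torch.tensor([4, 9, 2, 9, 1])

idx = torch.argmax(t)
print(idx)        # tensor(1): index of the first maximum, not the value 9
print(idx.dtype)  # torch.int64

# Without dim, a 2D tensor is flattened first, so the result is a single flat index
m = torch.tensor([[1, 2, 3],
                  [8, 5, 6]])
print(torch.argmax(m))  # tensor(3): 8 sits at flattened position 3 (row 1, column 0)
```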

PyTorch – argmax()

In this PyTorch argmax() article, we’ll demonstrate how to use argmax() to return the index positions of a tensor’s maximum values.

PyTorch is an open-source framework for the Python programming language. PyTorch stores data in tensors, which are multidimensional arrays. To use tensors, we must import the torch module, and the torch.tensor() function is used to build a tensor from existing data.
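As a minimal sketch of building a tensor with torch.tensor(), the snippet below converts a hand-written nested list into a 2D tensor and inspects it. The float32 result assumes PyTorch's default floating-point dtype has not been changed.

```python
import torch

# torch.tensor() copies Python data (here, a nested list) into a new tensor
data = [[1.0, 2.0, 3.0],
        [4.0, 5.0, 6.0]]
t = torch.tensor(data)

print(t.shape)  # torch.Size([2, 3])
print(t.dtype)  # torch.float32 (default dtype for Python floats)
print(t.ndim)   # 2
```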

PyTorch's argmax() function returns the index of the greatest element in the input tensor, either over the whole tensor or along a given dimension. The argmax() syntax is as follows:

torch.argmax(input, dim=None, keepdim=False)

According to the syntax:

input is the input tensor.

dim is the dimension to reduce. For a 2D tensor, dim=0 compares values down each column and returns, for every column, the row index of its maximum; dim=1 compares values across each row and returns, for every row, the column index of its maximum. If dim is omitted, the index of the maximum in the flattened tensor is returned.

keepdim determines whether the output tensor keeps the reduced dimension (with size 1). It defaults to False and is ignored when dim is omitted.
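To make the roles of dim and keepdim concrete, here is a small sketch on a hand-picked 2 * 3 tensor (the values are illustrative assumptions):

```python
import torch

x = torch.tensor([[1, 9, 3],
                  [4, 2, 6]])  # shape (2, 3)

# dim=0: compare down each column -> one row index per column
print(torch.argmax(x, dim=0))  # tensor([1, 0, 1])

# dim=1: compare across each row -> one column index per row
print(torch.argmax(x, dim=1))  # tensor([1, 2])

# keepdim=True keeps the reduced dimension with size 1
print(torch.argmax(x, dim=1, keepdim=True).shape)  # torch.Size([2, 1])
print(torch.argmax(x, dim=1).shape)                # torch.Size([2])
```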

# Start by importing the required library
import torch

# definition of an input tensor
var_input = torch.tensor([0., -1., 2., 8.])

# printing the above-defined tensor
print("The Input Tensor Val:", var_input)

# Compute the index of the highest value
var_indices = torch.argmax(var_input)

# printing the given indices
print("Indices:", var_indices)

The Python example above finds the index of the largest element in a 1D input tensor. The maximum value, 8, sits at index 3, so argmax() returns tensor(3).
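Because argmax() returns only the position, the maximum value has to be looked up separately. A minimal sketch reusing the same 1D tensor:

```python
import torch

var_input = torch.tensor([0., -1., 2., 8.])
var_indices = torch.argmax(var_input)

# argmax() returns only the index; use it to look up the value itself
max_value = var_input[var_indices]
print(var_indices.item(), max_value.item())  # 3 8.0

# torch.max() without dim returns the maximum value directly
print(torch.max(var_input).item())  # 8.0
```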

Example: create a tensor with two dimensions

This example creates a two-dimensional tensor with six rows and eight columns and applies argmax() along its columns and rows.

#start by importing the torch module

import torch

# 2 dimensions (6 * 8) tensor creation

# through elements that are random using the randn() function

var_data = torch.randn(6,8)

print(var_data)

# dim=0: row index of the maximum value in each column

print(torch.argmax(var_data, dim=0))

# dim=1: column index of the maximum value in each row

print(torch.argmax(var_data, dim=1))
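A quick way to check what each call returns is to look at the output shapes and compare the selected elements with torch.max(). This sketch adds a fixed seed (an assumption for reproducibility, not part of the original example):

```python
import torch

torch.manual_seed(0)          # fixed seed so the run is reproducible
var_data = torch.randn(6, 8)  # 6 rows, 8 columns

col_idx = torch.argmax(var_data, dim=0)  # one row index per column
row_idx = torch.argmax(var_data, dim=1)  # one column index per row

print(col_idx.shape)  # torch.Size([8])
print(row_idx.shape)  # torch.Size([6])

# The element at each chosen position equals the column / row maximum
assert torch.equal(var_data[col_idx, torch.arange(8)], var_data.max(dim=0).values)
assert torch.equal(var_data[torch.arange(6), row_idx], var_data.max(dim=1).values)
```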

Example: argmax() on a 4 * 4 tensor along different dimensions

In this program, we compute the index of the maximum value of a 4 * 4 tensor, first over the whole tensor and then along each dimension.

# Bring the necessary library in
import torch

# define a random 4 * 4 input tensor
input_tensor = torch.randn(4, 4)

# print the tensor defined above
print("Input Tensor:", input_tensor)

# index of the highest value in the flattened tensor
indices = torch.argmax(input_tensor)

# printing the indices
print("Indices:", indices)

# indices of the highest values along dim 0 (one per column)
indices = torch.argmax(input_tensor, dim=0)

# printing the indices
print("Indices in dim 0:", indices)

# indices of the highest values along dim 1 (one per row)
indices = torch.argmax(input_tensor, dim=1)

# printing the indices
print("Indices in dim 1:", indices)

In the Python example above, we find the indices of the greatest values of a 2D input tensor, over the whole tensor and along each dimension. Because torch.randn() fills the tensor with random values, your input tensor and indices will differ from the ones shown here and from run to run.
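When argmax() is called without dim, the single index it returns refers to the flattened tensor. A minimal sketch of converting it back to a (row, column) position, assuming a contiguous row-major 4 * 4 tensor and a fixed seed for reproducibility:

```python
import torch

torch.manual_seed(0)
input_tensor = torch.randn(4, 4)

flat_index = torch.argmax(input_tensor).item()  # index into the flattened tensor

# Convert the flat index back to (row, column) for a row-major 4 * 4 tensor
n_cols = input_tensor.shape[1]
row, col = divmod(flat_index, n_cols)

print(flat_index, (row, col))
# The element at (row, col) is the overall maximum
assert input_tensor[row, col] == input_tensor.max()
```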

Example: argmax() on an 8 * 8 tensor

#start by importing the torch module

import torch

#creation of 2 dimensions (8 * 8) tensor

#with random elements using randn() function

data = torch.randn(8,8)

#display

print(data)

# dim=0: row index of the maximum value in each column
print(torch.argmax(data, dim=0))

# dim=1: column index of the maximum value in each row
print(torch.argmax(data, dim=1))

Utilizing a CPU

The first step is to make sure the tensor is in CPU memory by calling its cpu() method, so that argmax() then runs on the CPU. cpu() returns a copy of the tensor in CPU memory, or the original tensor if it is already there. Because torch.randn() creates tensors on the CPU by default, cpu() changes nothing in the examples below, but it matters when a tensor lives on a GPU.

The syntax is as follows:

torch.tensor(data).cpu()

To illustrate this concept, we will use an 8 * 8 tensor with cpu() and apply argmax() as demonstrated below:

#import torch module

import torch

#creation of a tensor with two dimensions (8 * 8)

#with random elements using randn() function with cpu()

var_data = torch.randn(8,8).cpu()

print(var_data)

# dim=0: row index of the maximum value in each column

print(torch.argmax(var_data, dim=0))

# dim=1: column index of the maximum value in each row
print(torch.argmax(var_data, dim=1))
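In practice, cpu() is most useful for bringing results back from a GPU. The sketch below is an illustrative assumption rather than the article's method: it picks a CUDA device when one is available, falls back to the CPU otherwise, and moves the argmax() result to CPU memory at the end.

```python
import torch

# Use a GPU if one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

var_data = torch.randn(8, 8, device=device)

# argmax() runs on whichever device holds the tensor
row_idx = torch.argmax(var_data, dim=1)

# cpu() moves the result to CPU memory (no copy if it is already there)
row_idx_cpu = row_idx.cpu()
print(row_idx_cpu.device)  # cpu
print(row_idx_cpu.tolist())
```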

Example: using the cpu() function to generate a tensor

This example uses the cpu() function to generate a two-dimensional tensor with three rows and five columns, then applies argmax() along its columns and rows.

# Start by importing the necessary libraries
# import torch module

import torch

# creation of a tensor with two dimensions (3 * 5)

#with random elements using randn() with cpu() function

data = torch.randn(3,5).cpu()

#display

print(data)

# dim=0: row index of the maximum value in each column
print(torch.argmax(data, dim=0))

# dim=1: column index of the maximum value in each row

print(torch.argmax(data, dim=1))

Example: using the cpu() function to generate a tensor with two dimensions

The example below uses the cpu() function to generate a two-dimensional tensor with five rows and eight columns, then applies argmax() along its columns and rows.

# start by importing the torch module

import torch

#2 dimensional (5 * 8) tensor creation

#using the cpu() function and randn() with random elements

var_data = torch.randn(5,8).cpu()

print(var_data)

# dim=0: row index of the maximum value in each column
print(torch.argmax(var_data, dim=0))

# dim=1: column index of the maximum value in each row

print(torch.argmax(var_data, dim=1))

argmax Examples

For a better understanding, let's look at a few more examples of the argmax() function.

import torch

i_tensor = torch.randn(3, 3)
max_ele = torch.argmax(i_tensor)
print(" input tensor:", i_tensor)
print("Index of the max element (flattened):", max_ele)

In the example above, we import torch, build a random 3 * 3 tensor with the randn() function, and call argmax() without a dim argument. The result is a single index into the flattened tensor (a number from 0 to 8) that marks the position of the maximum value, not the maximum value itself. Finally, we print the outcome.

Here is another illustration of the argmax() function.

import torch

i_tensor = torch.randn(3, 3)
max_ele = torch.argmax(i_tensor, dim=1)
print(" input tensor:", i_tensor)
print("Index of the max element in each row:", max_ele)

This example is the same as the previous one except that we pass dim=1, so argmax() returns one index per row: the column position of the largest element in each of the three rows.
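The per-row indices can be used to pull out the per-row maxima with torch.gather(). This is a minimal sketch on a hand-picked 3 * 3 tensor (illustrative values); keepdim=True is used so the index tensor has the (3, 1) shape that gather() expects here.

```python
import torch

i_tensor = torch.tensor([[0.5, 2.0, -1.0],
                         [3.0, 0.0, 1.0],
                         [-2.0, 4.0, 4.5]])

# One column index per row (dim=1), kept as shape (3, 1) for use with gather()
idx = torch.argmax(i_tensor, dim=1, keepdim=True)

# gather() picks, in every row, the element at the column given by idx
row_max = torch.gather(i_tensor, dim=1, index=idx)

print(idx.squeeze(1).tolist())      # [1, 0, 2]
print(row_max.squeeze(1).tolist())  # [2.0, 3.0, 4.5]
```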

The examples above show how to use the argmax() method. Using the example below, let's examine the difference between the max() and argmax() functions.

import torch

input = torch.randn([3, 4])
print(input)
max_ele, max_indic = torch.max(input, dim=1)
print(max_ele)
print(max_indic)

In the example above, we import torch, use the randn() function to generate a random 3 * 4 tensor, and pass it to torch.max() with dim=1. Unlike argmax(), torch.max() called with a dim argument returns two tensors: the maximum value in each row and the index of that value. Finally, we print both the maximum elements and their indices; the indices are the same ones torch.argmax(input, dim=1) would return.
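To compare the two functions side by side on fixed data, here is a short sketch using a hand-picked 3 * 4 tensor (illustrative values, chosen without ties):

```python
import torch

x = torch.tensor([[1.0, 7.0, 3.0, 2.0],
                  [6.0, 0.0, 5.0, 4.0],
                  [2.0, 2.5, 9.0, 8.0]])

# torch.max with dim returns a named tuple of (values, indices)
result = torch.max(x, dim=1)
print(result.values.tolist())   # [7.0, 6.0, 9.0]
print(result.indices.tolist())  # [1, 0, 2]

# argmax along the same dim returns only the indices
print(torch.argmax(x, dim=1).tolist())  # [1, 0, 2]
```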

How torch.argmax() Operates on a 4-Dimensional Tensor

For a 4-dimensional input tensor with shape [1, 2, 3, 4], calling argmax() with dim=0 and without keepdim=True returns an output tensor with shape [2, 3, 4]. With dim=1, argmax() returns a tensor with shape [1, 3, 4], and the same pattern holds for the other dimensions. In other words, by default the dimension that argmax() works across is collapsed, since all of its values are replaced by a single index.

If keepdim is set to True in the argmax() method, the reduced dimension is not removed but kept with size 1. For instance, argmax() along dim=1 returns a tensor of shape [1, 1, 3, 4] for a 4D tensor of shape [1, 2, 3, 4].

In the program below, we create a random 4-dimensional tensor with the randn() function, pass it to the argmax() method, and examine the results along different dimensions with the default keepdim=False.

# start by importing the necessary libraries
import torch

# definition of a random 4D tensor
var_tensor = torch.randn(1, 2, 3, 4)
print("First Tensor:", var_tensor)
print(var_tensor.shape)

# apply argmax along dim 0 of the 4D tensor
print('--- Output tensor along dim 0 ---')
print(torch.argmax(var_tensor, dim=0, keepdim=False))
print(torch.argmax(var_tensor, dim=0, keepdim=False).shape)

# apply argmax along dim 2 of the 4D tensor
print('--- Output tensor along dim 2 ---')
print(torch.argmax(var_tensor, dim=2))
print(torch.argmax(var_tensor, dim=2).shape)
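To see the effect of keepdim on every dimension at once, this sketch loops over all four dimensions of a [1, 2, 3, 4] tensor and prints the output shape with keepdim=False and keepdim=True. Only the shapes matter here, so the random values are irrelevant.

```python
import torch

var_tensor = torch.randn(1, 2, 3, 4)

for d in range(var_tensor.dim()):
    default_shape = tuple(torch.argmax(var_tensor, dim=d).shape)
    kept_shape = tuple(torch.argmax(var_tensor, dim=d, keepdim=True).shape)
    print(f"dim={d}: keepdim=False -> {default_shape}, keepdim=True -> {kept_shape}")

# Expected:
# dim=0: keepdim=False -> (2, 3, 4), keepdim=True -> (1, 2, 3, 4)
# dim=1: keepdim=False -> (1, 3, 4), keepdim=True -> (1, 1, 3, 4)
# dim=2: keepdim=False -> (1, 2, 4), keepdim=True -> (1, 2, 1, 4)
# dim=3: keepdim=False -> (1, 2, 3), keepdim=True -> (1, 2, 3, 1)
```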

Conclusion

In this PyTorch argmax() article, we learned what argmax() is and how to use it to retrieve the indices of the maximum values across the columns and rows of a tensor. The torch.argmax() method takes a tensor as input and returns the indices of the input tensor's maximum values, either over the whole tensor or along a given dimension. If the input tensor has more than one maximal value, the index of the first maximal element is returned.

Additionally, we used the cpu() method to place tensors in CPU memory before retrieving the indices of their maximum values. For a 2D tensor, setting dim to 0 returns the index of the highest value in each column, whereas setting it to 1 returns the index of the highest value in each row.
