Batch Inference with PyTorch

Perform classification inference on a large sample fast on GPU, using PyTorch and Dask

In this example, we will perform a large volume image classification task on a GPU cluster with a pre-built model. We’re using the Stanford Dogs Dataset, so we will provide an image of a dog, and ask Resnet50 to give us the correct breed label.

Dataset: Stanford Dogs
Model: Resnet50

from dask_saturn import SaturnCluster
from dask.distributed import Client
from torchvision import datasets, transforms, models
import re
import s3fs
import dask
import toolz
import torch
import matplotlib.pyplot as plt

#### Load label targets ####
s3 = s3fs.S3FileSystem(anon=True)
with's3://saturn-public-data/dogs/imagenet1000_clsidx_to_labels.txt') as f:
    classes = [line.strip() for line in f.readlines()]

Assigning Objects to GPU Resources

If you are going to run any processes on GPU resources in a cluster, you need all your objects to be explicitly told this. Otherwise, it won’t seek out GPU resources. However, if you use a functional setup (as we are going to do later) you’ll need to do this INSIDE your function. Our architecture below will have all that written in.

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Create and Connect Cluster

This will create and start a cluster if you haven’t already created one. If you need more information about working with clusters in Dask, we have reference material to help.

cluster = SaturnCluster(
    n_workers = 6, 
    scheduler_size = 'medium', 
    worker_size = 'g4dn4xlarge',
    nthreads = 16

client = Client(cluster)

Notice that this will wait until at least three of the requested six clusters is ready to use.

Preprocessing Images

Our goal here is to create a nicely streamlined workflow, including loading, transforming, batching, and labeling images, which we can then run in parallel. Our data preprocessing function is set up to take images from S3, load and transform them, and extract the label and the unique identifier from the filepath.

def preprocess(path, fs=__builtins__):
    Ingest images directly from S3, apply transformations,
    and extract the ground truth and image identifier. Accepts
    a filepath. 
    transform = transforms.Compose([

    with, 'rb') as f:
        img ="RGB")
        nvis = transform(img)

    truth ='dogs/Images/n[0-9]+-([^/]+)/n[0-9]+_[0-9]+.jpg', path).group(1)
    name ='dogs/Images/n[0-9]+-[a-zA-Z-_]+/(n[0-9]+_[0-9]+).jpg', path).group(1)
    return [name, nvis, truth]

This function does a number of things for us.

  • Open an image file from S3
  • Apply transformations to image
  • Retrieve a unique identifier for the image
  • Retrieve the ground truth label for the image

But you’ll notice that this has a @dask.delayed decorator, so we can queue it without it running immediately when called. Because of this, we can use some list comprehension strategies to create our batches and get them ready for our inference.

If you’re new to the idea of delayed functions, you might want to visit our deep dive article about Dask concepts.

This function restructures our data to make it work better with PyTorch’s expectations.

def reformat(batch):
    Structure the dataset to allow inference in batches.
    flat_list = [item for item in batch]
    tensors = [x[1] for x in flat_list]
    names = [x[0] for x in flat_list]
    labels = [x[2] for x in flat_list]
    tensors = torch.stack(tensors).to(device)
    return [names, tensors, labels]

This diagram explains what the process is doing- we go from three item wide lists initially, to long lists of a single data type (tensors, names, or labels).

Diagram showing lists being converted to long

Evaluation Functions

This takes our results from the model, and a few other elements, to return nice readable predictions and the probabilities the model assigned. From here, we’re nearly done! We want to pass our results back to S3 in a tidy, human readable way, so the rest of the function handles that. It will iterate over each image because these functionalities are not batch handling. is_match is one of our custom functions, which you can check out below.

def evaluate_pred_batch(batch, gtruth, classes):
    Accepts batch of images, returns human readable predictions. 
    _, indices = torch.sort(batch, descending=True)
    percentage = torch.nn.functional.softmax(batch, dim=1)[0] * 100
    preds = []
    labslist = []
    for i in range(len(batch)):
        pred = [(classes[idx], percentage[idx].item()) for idx in indices[i][:1]]

        labs = gtruth[i]
    return(preds, labslist)

def is_match(label, pred):
    Evaluates human readable prediction against ground truth.
    if'_', ' '), str(pred).replace('_', ' ')):
        match = True
        match = False

Inference Workflow Function

Now, we aren’t going to patch together all these computations by hand, instead we have assembled them in one single delayed function that will do the work for us. Importantly, we can then map this across all our batches of images across the cluster!

PyTorch has the companion library torchvision which gives us access to a number of handy tools, including copies of popular models like Resnet. You can learn more about the available models in the torchvision documentation.

def run_batch_to_s3(iteritem):
    Accepts iterable result of preprocessing, generates inferences and evaluates. 
    names, images, truelabels = iteritem
    with's3://saturn-public-data/dogs/imagenet1000_clsidx_to_labels.txt') as f:
        classes = [line.strip() for line in f.readlines()]

    # Retrieve, set up model
    resnet = models.resnet50(pretrained=True)
    resnet =

    with torch.no_grad():
        pred_batch = resnet(images)
        #Evaluate batch
        preds, labslist = evaluate_pred_batch(pred_batch, truelabels, classes)

        #Organize prediction results
        outcomes = []
        for j in range(0, len(images)):
            match = is_match(labslist[j], preds[j])            
            outcome = {'name': names[j], 'ground_truth': labslist[j], 
                       'prediction': preds[j], 'evaluation': match}

Run Inference Workflow

Notice that we’re going to use client methods below to ensure that our tasks are distributed across the cluster, run, and then retrieved.

Preprocess image data (lazy evaluation)

s3fpath = 's3://saturn-public-data/dogs/Images/*/*.jpg'
batch_breaks = [list(batch) for batch in toolz.partition_all(80, s3.glob(s3fpath))]
image_batches = [[preprocess(x, fs=s3) for x in y] for y in batch_breaks]
image_batches = [reformat(result) for result in image_batches] 

Run inference workflow on batches

futures =, image_batches) 
futures_gathered = client.gather(futures)
futures_computed = client.compute(futures_gathered, sync=False)

results = []
errors = []

## Error handling allows us to capture when an input image is broken
for fut in futures_computed:
        result = fut.result()
    except Exception as e:

Evaluate Output

In this section, we’ll review the results our inference job produced. This calculates which predictions were true or false, and calculates the percent accurate.

true_preds = [x['evaluation'] for x in results if x['evaluation'] == True]
false_preds = [x['evaluation'] for x in results if x['evaluation'] == False]

Lessons Learned

  • You can apply @dask.delayed to your custom code to allow parallelization with nearly zero refactoring
  • Objects that are needed for a parallel task on GPU need to be assigned to a GPU resource
  • Passing tasks to the workers uses mapping across the cluster for peak efficiency And, of course, having multiple workers makes the job a lot faster!

Need help, or have more questions? Contact us at:

We'll be happy to help you and answer your questions!