Model Training (image data)

This tutorial focuses on training a deep learning model on image data using Blind Learning, an advanced privacy-preserving learning algorithm. Blind Learning is one of our algorithms that enables training advanced models without compromising the privacy of the data or the intellectual property of the model. Using Blind Learning, you can efficiently train your model on distributed datasets without ever "seeing" the data.

As a deep learning practitioner, you still have complete control over all phases of the model-training lifecycle, including:

  • Data collection and preparation
  • Model creation (architecture and hyperparameters)
  • Model validation
  • Inference


This tutorial trains a neural network to classify hand-written digits using the famous MNIST dataset. However, for illustration we will pretend that you do not own such a dataset and it is not freely available online. In practice, this is true for many real-world sensitive datasets, such as financial and healthcare data, or other highly-valued datasets such as customer behavior. We will use a copy of the MNIST dataset shared by another organization.

You will play the role of a deep learning practitioner who wishes to build a classifier using this dataset. You will be able to build and train an image-classification model while maintaining total privacy of the data. The training will also preserve the intellectual property of your model's parameters and architecture.

Data Discovery

You can find the datasets available for data analysis and training through 🔗Assets or directly from code via the TripleBlind SDK. Please refer to the Access Points & Assets tutorial to learn more about finding and accessing assets.

You will use the web interface to search for the MNIST dataset and retrieve its UUID.

  • Visit the 🔗Gallery page
  • Search for "TUTORIAL - MNIST"
  • Open the dataset Details page and copy its ID
  • Paste the ID in the code below

The Dataset Details page also includes important information about the dataset, such as the number of images, size, number of channels, etc. as seen in the screenshot below.

import warnings
import torch
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt

import tripleblind as tb


#### access the dataset using its UUID
dataset = tb.Asset("27e6e5e6-c281-425e-8a82-06e3d0c8dcc9")

if not dataset.is_valid:
   print("Wrong UUID, or dataset not found")
else:
   print(f"Successfully loaded: {dataset.uuid}")

Data processing

TripleBlind's SDK provides a wide range of functions that allow you to preprocess data for any advanced data analysis task, including training deep neural networks.

In the following, we will use the ImagePreprocessor to apply a set of data transformations to prepare for our training task. For this simple example you can apply the following:

  • Create a preprocessor
  • Define the target_column. This is specific to our API, which lists the classes of all images in a CSV file.
  • Resize all images to 28x28.
  • Convert all images to grayscale
  • Set image dimensions to channels_first, i.e., (C, H, W)
  • Set the dtype of the numpy arrays representing the images to float32

ℹī¸ If your training task uses distributed datasets, you do not have to define a separate preprocessor for each dataset. All data processing will apply to all datasets automatically -- assuming the distributed datasets are horizontally distributed (i.e., they are of the same type and features).

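pb = tb.preprocessor.image.ImagePreprocessor.builder()
pb.target_column("label")  # the column name "label" is an assumption; use the class column of your dataset's CSV
pb.resize(28, 28)
pb.convert("L")  # use grayscale
pb.channels_first()  # (C, H, W); method name assumed from the step list above
pb.dtype("float32")  # method name assumed from the step list above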
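To make the preprocessing steps concrete, here is a minimal NumPy sketch of what grayscale conversion, channels-first layout, and float32 casting do to a single image. It is a local illustration only, not the ImagePreprocessor implementation. The function name to_model_input is hypothetical, the grayscale weights are the standard ITU-R 601 luma coefficients, and resizing is left out by starting from an image that is already 28x28.

```python
import numpy as np


def to_model_input(rgb_hwc: np.ndarray) -> np.ndarray:
    """Turn an (H, W, 3) uint8 image into a (1, H, W) float32 array."""
    # Weighted sum over the colour axis gives a single grayscale channel.
    weights = np.array([0.299, 0.587, 0.114], dtype=np.float32)
    gray = rgb_hwc.astype(np.float32) @ weights  # shape (H, W)
    # Channels first: put the single channel in front of height and width.
    return gray[np.newaxis, :, :]  # shape (1, H, W)


rgb = np.zeros((28, 28, 3), dtype=np.uint8)
rgb[..., 0] = 255  # a pure red test image
x = to_model_input(rgb)
print(x.shape, x.dtype)
```

The resulting shape matches the [1, 28, 28] layout that the training configuration later declares as data_shape.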

Network definition

The SDK provides a NetworkBuilder class which supports assembling a network using simple methods to add layers and configurations, similar to PyTorch. After creating a NetworkBuilder you can use your editor's automatic completion function to find all supported layers or see the reference documentation for the full list of supported layers.

Notice in the network architecture below the command to add a special layer, builder.add_split(). The Split layer is specific to TripleBlind's Blind Learning algorithm. It enables training distributed deep learning models without having to share the datasets with the model creator. Instead, the model is split into two parts and distributed among data holders and the model creator. The Split layer ensures that only model parameters from a single layer are exchanged during the training process -- ensuring the privacy of the involved datasets.

ℹī¸ Placing the Split layer too early in the network will reduce the number of layers located at the data-owner side while increasing the computational needs at the server-side. We recommend placing the Split layer after the flatten_layer to balance both the privacy and the computational burdens on both the data owner and the algorithm provider.

#### Define the neural network we want to use for training
training_model_name = "example-mnist-network-trainer"

builder = tb.NetworkBuilder()
builder.add_conv2d_layer(1, 32, 3, 1)
builder.add_max_pool2d_layer(2, 2)
builder.add_conv2d_layer(32, 64, 3, 1)
builder.add_max_pool2d_layer(2, 2)


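builder.add_flatten_layer()  # method name assumed from the "flatten_layer" mentioned in the note above
builder.add_split()  # Split layer: layers above stay with the data owner, layers below run at the server
builder.add_dense_layer(1600, 128)
builder.add_dense_layer(128, 10)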

training_model = tb.create_network(training_model_name, builder)
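The input size of 1600 for the first dense layer follows from the shapes produced by the convolution and pooling layers. The sketch below reproduces that arithmetic in plain Python. It assumes that the four arguments of add_conv2d_layer are (in_channels, out_channels, kernel_size, stride), that add_max_pool2d_layer takes (kernel_size, stride), and that no padding is applied, as with the PyTorch defaults. The helper out_size is illustrative and not part of the SDK.

```python
def out_size(size: int, kernel: int, stride: int, padding: int = 0) -> int:
    """Spatial output size of a convolution or pooling layer."""
    return (size + 2 * padding - kernel) // stride + 1


side = 28                    # images are resized to 28x28
side = out_size(side, 3, 1)  # conv 3x3, stride 1 -> 26
side = out_size(side, 2, 2)  # max pool 2x2, stride 2 -> 13
side = out_size(side, 3, 1)  # conv 3x3, stride 1 -> 11
side = out_size(side, 2, 2)  # max pool 2x2, stride 2 -> 5

channels = 64                # output channels of the second conv layer
flattened = channels * side * side
print(flattened)             # 1600
```

If you change the image size, kernel sizes, or number of channels, rerun this calculation and update the first argument of the first dense layer accordingly.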

Training configuration

The SDK provides tb.create_job() to configure and create the training job. Within this function you configure the loss function, optimization algorithm, and other common hyperparameters in addition to parameters for privacy-preserving algorithms.

ℹī¸ The names of the loss functions and optimization algorithms, along with their parameters, are derived from PyTorch. You do not have to be familiar with PyTorch to use them, you only need to be aware of the functions' names such as CrossEntropyLoss for cross-entropy loss and Adam for the Adam optimizer.

The following code illustrates how to configure the loss function, optimizer, and the hyperparameters.

#### loss function names and parameters must be consistent with PyTorch
loss_name = "CrossEntropyLoss"

#### optimizer names and parameters must be consistent with PyTorch
optimizer_name = "Adam"
optimizer_params = {"lr": 0.001}

#### training hyperparameters
train_params = {
   "epochs": 2,
   "test_size": 0.2,   # splits the data into 20% test and 80% train
   "batchsize": 64,
   "loss_meta": {"name": loss_name},
   "optimizer_meta": {"name": optimizer_name, "params": optimizer_params},
   "data_type": "image",   # image/table
   "data_shape": [1, 28, 28],  # we use channels first because our training algorithm uses PyTorch
   "model_output": "multiclass",  # binary/multiclass/regression
}

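To get a feel for what epochs, test_size, and batchsize imply, the following sketch computes how many optimizer steps a run would take. The dataset size of 1000 is a made-up number, and keeping the last partial batch is an assumption borrowed from the PyTorch DataLoader default. The SDK may round differently. The function training_steps is hypothetical.

```python
import math


def training_steps(n_samples: int, test_size: float, batch_size: int, epochs: int) -> dict:
    """Estimate the train/test split and the number of optimizer steps."""
    n_test = round(n_samples * test_size)
    n_train = n_samples - n_test
    # Assumes the last, smaller batch is kept rather than dropped.
    batches_per_epoch = math.ceil(n_train / batch_size)
    return {
        "train": n_train,
        "test": n_test,
        "batches_per_epoch": batches_per_epoch,
        "total_steps": batches_per_epoch * epochs,
    }


plan = training_steps(n_samples=1000, test_size=0.2, batch_size=64, epochs=2)
print(plan)
```

For this hypothetical dataset, 800 images are used for training and 200 for testing, giving 13 batches per epoch and 26 steps in total.
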
Model training

After locating the data and defining the training configuration, you can create the training job with the datasets and all parameters. The job automatically takes care of training across distributed datasets (if there is more than one). The completed job returns the trained model as an asset.

ℹī¸ After running the following cell, you will notice that the job is waiting for permission. As discussed in previous tutorials, your training job will not start until permission is granted from the owners of the datasets that you wish to use in this job. For this tutorial, the dataset was given a wildcard Agreement which automatically allows any organization to access it.

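#### train the network
job = tb.create_job(
   job_name="MNIST training",
   # The remaining arguments were cut off in the source. The keyword names
   # below are assumptions; check them against the tb.create_job() reference.
   operation=training_model,
   dataset=[dataset],
   preprocessor=pb,
   params=train_params,  # the hyperparameter dictionary from "Training configuration"
)

if job.submit():
   # Wait for the job to finish before checking the result
   # (method name assumed; this call was missing from the source).
   job.wait_for_completion()

   if job.success:
       print("Trained Network Asset ID:")
       print("    ===============================================")
       print(f"    ===>  {job.result.asset.uuid} <===")
       print("    ===============================================")
       print("    Algorithm: Deep Learning Model")
       print(f"    Job ID:    {job.job_name}")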

Download the trained model

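#### Download trained model locally (this is a PyTorch model)
trained_network = job.result.asset

# The call on the next line was truncated in the source; download(save_as=...) is an assumption.
local_filename = trained_network.download(save_as="local_MNIST.pth", overwrite=True)
print(f"Trained network has been downloaded as: {local_filename}")

#### suppress the PyTorch "SourceChangeWarning" before loading the model
warnings.filterwarnings("ignore", category=torch.serialization.SourceChangeWarning)

#### load the model and use it as a PyTorch object
pack = tb.Package.load(local_filename)
my_model = pack.model()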

Local inference

Now you can make predictions using the trained model downloaded in the previous step.

You can perform a simple batch inference by creating a PyTorch loader and running the inferences using the model. See the Tabular Data tutorial for more on batch inference. Here, we will load a single image and run it through the model to observe the model's output.

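def image_loader(image_path):
   """Load an image file as a float tensor of shape (1, 1, H, W)."""
   image = Image.open(image_path)  # assumes a 28x28 grayscale image
   image = np.array(image)  # copy the pixels into a writable array
   image = torch.from_numpy(image).float().unsqueeze(0)  # add the channel dimension
   image = image.unsqueeze(0)  # add the batch dimension
   return image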

input_img = image_loader("img/four.jpg")

prediction = my_model(input_img).max(1)
print(f"The predicted label is: {prediction[1].item()}")
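The call my_model(input_img).max(1) returns the largest score in each row together with its index, and the index is the predicted digit. Assuming the ten outputs of the final dense layer are raw scores (logits), as is usual when training with CrossEntropyLoss, a softmax turns them into probabilities. The sketch below shows this in plain Python with made-up scores. In PyTorch the equivalent would be torch.softmax(output, dim=1).

```python
import math
from typing import List


def softmax(logits: List[float]) -> List[float]:
    """Convert raw scores into probabilities that sum to 1."""
    shift = max(logits)  # subtract the maximum for numerical stability
    exps = [math.exp(v - shift) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Made-up scores for the digits 0 through 9; not real model output.
logits = [-1.2, 0.3, 0.1, -0.5, 4.0, 0.0, -2.0, 1.1, 0.2, -0.7]
probs = softmax(logits)
label = max(range(len(probs)), key=probs.__getitem__)
print(f"The predicted label is: {label} (probability {probs[label]:.3f})")
```

Softmax does not change which class has the highest score, so the predicted label is the same with or without it. It only makes the confidence easier to read.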

Behind the Scenes

The underlying model-training algorithm used in this tutorial is called Horizontal Blind Learning. "Horizontal" refers to the fact that the distributed datasets are of the same type and include the same columns (in the case of tabular data). In contrast, "Vertical" refers to datasets with different features (e.g., columns) and possibly different data types across the organizations. See the figure below on Vertical versus Horizontal datasets.
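The difference between the two layouts can also be shown on a tiny table. The following sketch uses invented records and hypothetical helper names. It only illustrates how the data is divided between organizations, not how Blind Learning trains on it.

```python
from typing import Dict, List, Sequence, Tuple

Row = Dict[str, object]


def split_horizontally(rows: List[Row], cut: int) -> Tuple[List[Row], List[Row]]:
    """Each organization holds different rows with the same columns."""
    return rows[:cut], rows[cut:]


def split_vertically(
    rows: List[Row], key: str, left: Sequence[str], right: Sequence[str]
) -> Tuple[List[Row], List[Row]]:
    """Each organization holds different columns for the same rows, linked by a key."""
    part_a = [{key: r[key], **{c: r[c] for c in left}} for r in rows]
    part_b = [{key: r[key], **{c: r[c] for c in right}} for r in rows]
    return part_a, part_b


# Invented example records.
records = [
    {"id": 1, "age": 34, "income": 52000, "label": 0},
    {"id": 2, "age": 51, "income": 87000, "label": 1},
    {"id": 3, "age": 29, "income": 43000, "label": 0},
    {"id": 4, "age": 45, "income": 61000, "label": 1},
]

org_a, org_b = split_horizontally(records, cut=2)
print(sorted(org_a[0]) == sorted(org_b[0]))  # True: same columns, different rows

org_c, org_d = split_vertically(records, key="id", left=["age"], right=["income", "label"])
print(sorted(org_c[0]), sorted(org_d[0]))    # ['age', 'id'] ['id', 'income', 'label']
```

The MNIST setup in this tutorial corresponds to the horizontal case: every data holder would contribute images of the same shape with the same label column.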