Building Robust and Private Federated Deep Learning

What is Federated Learning trying to solve?

The power of Deep Learning resides in the richness and precision of the data we use to train deep networks. We are now at a stage where we often have more data than the compute power needed to use it, so we have developed strategies that leverage several compute instances to learn a model collectively from larger datasets.

But there is one problem: the data is spatially scattered. It may reside with different entities, organizations, or companies, and it may even be located in different geographic locations. For instance, cancer treatment has generated a large number of tissue images that doctors have labeled as malignant or benign.

As a Deep Learning practitioner, you can imagine the scale of data we would have if a group of hospitals or researchers pooled such images. However, conventional training requires bringing the data to one machine or cluster. This becomes an obstacle when the data is huge (especially image data) or when privacy concerns are attached to it (which is very common with personal data).

To solve these issues, the community has come up with what is called Federated Learning.

Formally defining Federated Learning

Federated Learning in the hospital dataset setting (Reference: Federated Learning powered by NVIDIA Clara)

Federated Learning enables multiple learners to collaboratively learn a shared prediction model while each learner's training data stays isolated with that learner, thus decoupling the ability to do machine learning from the need to store the data in a single place.
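For reference, [3] writes the underlying optimization problem roughly as follows: with K clients, client k holding the n_k training examples indexed by P_k (n examples in total), and f_i(w) the loss of the shared model w on example i, the global objective is a data-weighted average of local objectives F_k, each of which its owner can evaluate without moving the data.

```latex
\min_{w} \; f(w) = \sum_{k=1}^{K} \frac{n_k}{n} F_k(w),
\qquad
F_k(w) = \frac{1}{n_k} \sum_{i \in \mathcal{P}_k} f_i(w)
```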

Hence, Federated Learning can be applied to the hospital data setting discussed above. Another application worth knowing is collaborative on-device training of ML models on users' mobile devices, which allows models to be trained without taking users' sensitive data off their devices.

The user’s phone trains the model using the user’s data (A). Many users’ updates are aggregated (B) to form a consensus change (C) to the shared model, after which the procedure is repeated. (Image Reference: Federated Learning: Collaborative Machine Learning without Centralized Training Data)

Training in this scenario works as follows: a mobile device downloads the current model, trains it on the data stored on the user's phone, and uploads the resulting model update to the server, where it is averaged with other users' updates to improve the shared model. The training data itself never leaves the user's device. Voila!
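To make the averaging step concrete, here is a minimal sketch of federated averaging in the spirit of [3]. It assumes a toy 1-D linear model, plain SGD on each client, and a server that weights each client's model by its number of local examples; the function names and data are hypothetical, not taken from any federated learning library.

```python
from typing import List, Tuple

Weights = List[float]                # [intercept, slope] of a toy 1-D linear model
Dataset = List[Tuple[float, float]]  # (x, y) pairs that stay on one client


def local_update(weights: Weights, data: Dataset, lr: float = 0.05, epochs: int = 5) -> Weights:
    """Client side: a few epochs of plain SGD on squared error, starting from the global model."""
    w0, w1 = weights
    for _ in range(epochs):
        for x, y in data:
            err = (w0 + w1 * x) - y   # prediction error on one local example
            w0 -= lr * 2 * err        # gradient of err**2 w.r.t. w0
            w1 -= lr * 2 * err * x    # gradient of err**2 w.r.t. w1
    return [w0, w1]


def federated_average(client_weights: List[Weights], client_sizes: List[int]) -> Weights:
    """Server side: average the client models, weighting client k by n_k / n."""
    total = sum(client_sizes)
    return [
        sum(w[j] * n for w, n in zip(client_weights, client_sizes)) / total
        for j in range(len(client_weights[0]))
    ]


def train(clients: List[Dataset], rounds: int = 200) -> Weights:
    global_w = [0.0, 0.0]
    for _ in range(rounds):
        # Each client starts from the global model and trains on data only it can see;
        # only the resulting weights are sent back to the server.
        updates = [local_update(global_w, data) for data in clients]
        global_w = federated_average(updates, [len(data) for data in clients])
    return global_w


# Toy data: two clients holding disjoint samples of the line y = 1 + 2x.
CLIENTS = [
    [(0.0, 1.0), (0.5, 2.0), (1.0, 3.0)],
    [(1.5, 4.0), (2.0, 5.0)],
]
print(train(CLIENTS))  # approximately [1.0, 2.0]
```

Only weights cross the network here; swapping the toy model for a neural network changes `local_update`, not the protocol.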

Federated Learning allows for smarter models, lower latency, and less power consumption, while helping to preserve privacy.

Diving Deeper into how to train models in a distributed setting (Paper 1)

In this section, we explore the configurations proposed in the SplitNN paper by Vepakomma et al. [1].

The SplitNN configurations facilitate distributed deep learning without sharing raw data, and they target the following settings:

  1. entities holding different modalities of data
  2. centralized and local health entities collaborating on multiple tasks
  3. learning without sharing labels

The paper by Vepakomma et al. focuses specifically on what heavily impedes collaboration in health: lack of trust. The main idea is to update a shared Deep Learning model together without sharing data among the learners.

Some common characteristics to note before we proceed to look at the configurations:

  • A configuration consists of a server and multiple clients. A client is an entity that has access to a subset of the data (a subset of the samples or, in the vertically partitioned configuration, a subset of the features of each sample) and to some compute of its own.
  • During training, activations and gradients need to be propagated over the network between the clients and the server.
  • Data (x, y) resides on the client machines. In some configurations, labels (y) need to be transported over the network to the server.
  • In all the configurations, the model is split between client and server; the layer at which the model is split is called the cut layer (depicted in green in the following diagrams).

Simple vanilla configuration for split learning

Simple Vanilla Configuration (Reference: [1])

In this configuration, each client trains a partial deep network up to a specific layer known as the cut layer. The activations from the cut layer and the corresponding labels are transported to the server over the network. Each client looks only at the data it possesses to compute activations. In this way, clients never need to share their raw inputs (the x of each data sample) to train the network; they do, however, share the labels y, a limitation that the next configuration resolves.

After forward propagation is completed, the server computes the loss using the labels it received. The loss is then backpropagated through the server's layers down to the cut layer, and the gradients at the cut layer are sent back to the client over the network, where backpropagation continues through the client's layers.

This procedure is repeated until the distributed split learning network is trained, without any party looking at another's raw data.
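The vanilla protocol can be sketched with a deliberately tiny model: a hypothetical `Client` holding one tanh unit below the cut layer and a hypothetical `Server` holding one linear unit above it, trained with squared error. This illustrates the message flow under those assumptions and is not the authors' implementation; note that the label y travels to the server together with the activation.

```python
import math
from typing import List, Tuple


class Client:
    """Holds the raw inputs x and the layers up to the cut layer (here: one tanh unit)."""

    def __init__(self, w: float = 0.5, lr: float = 0.1):
        self.w, self.lr = w, lr
        self._x = 0.0
        self._h = 0.0

    def forward(self, x: float) -> float:
        self._x = x
        self._h = math.tanh(self.w * x)  # cut-layer activation: sent instead of x
        return self._h

    def backward(self, grad_h: float) -> None:
        # Continue backpropagation from the cut layer: dL/dw = dL/dh * (1 - h^2) * x
        self.w -= self.lr * grad_h * (1.0 - self._h ** 2) * self._x


class Server:
    """Holds the layers after the cut layer (here: one linear unit) and computes the loss."""

    def __init__(self, w: float = 0.5, lr: float = 0.1):
        self.w, self.lr = w, lr

    def step(self, h: float, y: float) -> Tuple[float, float]:
        err = self.w * h - y               # uses the label y sent by the client
        grad_h = 2.0 * err * self.w        # dL/dh, computed before updating self.w
        self.w -= self.lr * 2.0 * err * h  # dL/dw_server = 2 * err * h
        return err ** 2, grad_h


def train_epoch(client: Client, server: Server, data: List[Tuple[float, float]]) -> float:
    total = 0.0
    for x, y in data:
        h = client.forward(x)             # client -> server: activation h and label y
        loss, grad_h = server.step(h, y)  # server -> client: gradient at the cut layer
        client.backward(grad_h)
        total += loss
    return total / len(data)


# Toy data generated by y = 2 * tanh(x); the client never sends x itself.
DATA = [(x, 2.0 * math.tanh(x)) for x in (-1.0, -0.5, 0.5, 1.0)]
client, server = Client(), Server()
losses = [train_epoch(client, server, DATA) for _ in range(300)]
```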

U-shaped configurations for split learning without label sharing

U-shaped configuration (Reference: [1])

This configuration is devised to remove the need to share labels. In addition to the cut layer on the client side, the network is cut a second time near its end: the final layers, including the classification layer and the loss computation, are brought back to the client. Because the loss is computed on the client, training proceeds without sharing the labels.

This setup is ideal for distributed deep learning in settings where even the labels are considered sensitive.
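Under the same toy assumptions as before (scalar tanh and linear units, squared error, hypothetical class names), the U-shaped variant adds a second cut: the client keeps both the first layer and the head, so the server only ever handles activations and gradients.

```python
import math
from typing import List, Tuple


class UClient:
    """Holds x, y, the bottom layer (one tanh unit) and the head (one linear unit + loss)."""

    def __init__(self, a: float = 0.5, c: float = 0.5, lr: float = 0.05):
        self.a, self.c, self.lr = a, c, lr
        self._x = 0.0
        self._h1 = 0.0

    def bottom_forward(self, x: float) -> float:
        self._x = x
        self._h1 = math.tanh(self.a * x)  # first cut: activation sent to the server
        return self._h1

    def head_step(self, h2: float, y: float) -> Tuple[float, float]:
        """Finish the forward pass and compute the loss locally; return (loss, dL/dh2)."""
        err = self.c * h2 - y             # the label y is only used here, on the client
        grad_h2 = 2.0 * err * self.c      # computed before updating self.c
        self.c -= self.lr * 2.0 * err * h2
        return err ** 2, grad_h2

    def bottom_backward(self, grad_h1: float) -> None:
        self.a -= self.lr * grad_h1 * (1.0 - self._h1 ** 2) * self._x


class UServer:
    """Holds only the middle layer; it sees activations and gradients, never x or y."""

    def __init__(self, b: float = 0.5, lr: float = 0.05):
        self.b, self.lr = b, lr
        self._h1 = 0.0

    def forward(self, h1: float) -> float:
        self._h1 = h1
        return self.b * h1                # second cut: activation sent back to the client

    def backward(self, grad_h2: float) -> float:
        grad_h1 = grad_h2 * self.b        # computed before updating self.b
        self.b -= self.lr * grad_h2 * self._h1
        return grad_h1


def train_epoch_u(client: UClient, server: UServer, data: List[Tuple[float, float]]) -> float:
    total = 0.0
    for x, y in data:
        h2 = server.forward(client.bottom_forward(x))
        loss, grad_h2 = client.head_step(h2, y)
        client.bottom_backward(server.backward(grad_h2))
        total += loss
    return total / len(data)


DATA = [(x, 2.0 * math.tanh(x)) for x in (-1.0, -0.5, 0.5, 1.0)]
u_client, u_server = UClient(), UServer()
u_losses = [train_epoch_u(u_client, u_server, DATA) for _ in range(500)]
```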

Vertically partitioned data for split learning

Configuration for Vertically Partitioned Data (Reference: [1])

This configuration helps learn distributed models in settings where multiple institutions (clients) hold different modalities of data for the same i-th sample, without sharing data among the clients.

For example, one patient's data may reside with different entities: some of Patient A's information may reside with the radiology department and other information with the pathology department. Suppose both departments want to train a disease diagnosis model that requires combining data from both (that is, different modalities of data for the same sample must be combined), but neither is willing to share its data directly with the other. In such a scenario, the configuration for vertically partitioned data can be employed.

Now, let's discuss how training operates in this configuration. Both departments pre-share the batch size and the order of the data (either the names of patients or the names hashed with a mutually agreed hashing function). Both clients then perform forward propagation through their partial networks. The outputs at the cut layer from both centers are concatenated and sent to the server, which trains the rest of the model. This process goes back and forth through the forward and backward propagations, training the distributed deep learning model without either party sharing its raw data.
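A minimal sketch of the alignment and concatenation steps, assuming SHA-256 with a hypothetical shared salt as the mutually agreed hashing function, and sorting the common hashed IDs as the pre-shared order. `PartialNet` is a hypothetical single-layer stand-in for each department's partial network; only the cut-layer outputs reach the server.

```python
import hashlib
import math
from typing import Dict, List

# Hypothetical shared secret agreed by both departments out of band.
SALT = b"agreed-salt"


def hashed_id(patient_id: str) -> str:
    """Mutually agreed hashing function applied to patient names before they are shared."""
    return hashlib.sha256(SALT + patient_id.encode("utf-8")).hexdigest()


def agreed_order(ids_a: List[str], ids_b: List[str]) -> List[str]:
    """Shared sample order: hashed IDs present at both clients, sorted deterministically."""
    return sorted(set(ids_a) & set(ids_b))


class PartialNet:
    """Client-side layers up to the cut layer: one tanh layer with illustrative weights."""

    def __init__(self, weights: List[List[float]]):
        self.weights = weights  # one row of input weights per cut-layer unit

    def forward(self, features: List[float]) -> List[float]:
        return [math.tanh(sum(w * f for w, f in zip(row, features))) for row in self.weights]


def server_batch(rad: Dict[str, List[float]], path: Dict[str, List[float]],
                 rad_net: PartialNet, path_net: PartialNet, order: List[str]) -> List[List[float]]:
    """What the server receives: both cut-layer outputs concatenated, one row per sample."""
    return [rad_net.forward(rad[h]) + path_net.forward(path[h]) for h in order]


# Each department keys its own features by hashed patient ID (toy values).
radiology = {hashed_id("patient-A"): [0.2, 0.7], hashed_id("patient-B"): [0.9, 0.1]}
pathology = {hashed_id("patient-B"): [1.0], hashed_id("patient-A"): [0.3], hashed_id("patient-C"): [0.5]}

order = agreed_order(list(radiology), list(pathology))
batch = server_batch(radiology, pathology,
                     PartialNet([[0.5, -0.5], [1.0, 1.0]]), PartialNet([[2.0]]), order)
```

During backpropagation, the server would split the gradient of each concatenated row back into the part belonging to each department and return it to that department.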

Note that this configuration can also be modified similar to the U-shaped configuration in order to avoid label sharing.

Experimental Evaluation of SplitNN

The authors of the SplitNN paper [1] evaluated its performance in two simplified settings:

(i) VGG on CIFAR-10 with 100 clients

(ii) ResNet-50 on CIFAR-100 with 500 clients

The authors compared SplitNN's performance with large-batch synchronous SGD and with the technique from [3], termed "Federated Learning" in the comparison.

Validation Accuracy vs Computation

The authors show a dramatic reduction in computational burden (in TFLOPS) with SplitNN while maintaining higher accuracies when training over a large number of clients. Left graph: 100 clients with VGG on CIFAR-10 | Right graph: 500 clients with ResNet-50 on CIFAR-100 (Reference: [1])

In these experiments, SplitNN reaches higher validation accuracy for a given amount of computation per client, which makes it a very promising contender in the Federated Learning domain.

Computation bandwidth analysis

Computation bandwidth required per client

The table above shows that, when training ResNet on CIFAR-100 with a large number of clients, the bandwidth required per client is lower for SplitNN than for large-batch SGD and federated learning [3]. With a smaller number of clients, however, federated learning requires less bandwidth than SplitNN.
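The trend can be reproduced with a back-of-envelope count of floats exchanged per client per epoch. This is a simplified model of our own, not the accounting used in [1]: federated learning sends the full model down and up once per epoch, while split learning sends one cut-layer activation per local sample and receives a gradient of the same size. All sizes below are hypothetical.

```python
def fl_floats_per_client(num_params: int) -> float:
    """Federated learning, simplified: per epoch, a client downloads and uploads the full model once."""
    return 2.0 * num_params


def split_floats_per_client(dataset_size: int, num_clients: int, cut_size: int) -> float:
    """Split learning, simplified: per epoch, a client sends one cut-layer activation per local
    sample and receives a gradient of the same size. Ignores client-side weight hand-offs
    and label traffic."""
    samples_per_client = dataset_size / num_clients
    return 2.0 * samples_per_client * cut_size


# Hypothetical sizes: 50,000 training images, a 25M-parameter model, a 4,096-float cut layer.
N_PARAMS, N_SAMPLES, CUT = 25_000_000, 50_000, 4_096
for k in (5, 100, 500):
    print(k, split_floats_per_client(N_SAMPLES, k, CUT), fl_floats_per_client(N_PARAMS))
```

With these made-up sizes the crossover sits near K = n·q/N ≈ 8 clients: split learning traffic shrinks as the data is spread over more clients, while federated learning traffic per client is fixed by the model size. This mirrors the qualitative trend in the table, not its numbers.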

Conclusion

In the reported experiments, SplitNN is dramatically more resource-efficient than the distributed deep learning methods it was compared against. It is also versatile, offering several plug-and-play configurations depending on the application.

Exploring model attack challenges (Paper 2)

In this section, we will discuss the paper "Analyzing Federated Learning through an Adversarial Lens" by Bhagoji et al. [2], which explores the threat of model poisoning attacks in federated learning.

The paper demonstrates how stealthily a malicious agent can affect training without being detected. The authors show this in a scenario with multiple learners (clients) and a server.

Model Training Setting

Multiple learners or clients perform model training and share only the parameter updates with a centralized parameter server.

Parameter Server architecture that the authors assume for the training setting (Image Reference: [4] Large Scale Distributed Deep Networks)

Targeted model poisoning (Approach 1: Naive)

The authors consider a scenario in which the attack is carried out by a single, non-colluding malicious agent among the clients (in the image above, each model replica is trained by one client on its own dataset).

Adversarial Objective

Mathematically, the objective function for the adversary can be represented by the following equation:
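Written in the notation defined just below, and following our reading of [2], the objective counts how many of the r auxiliary examples the global model assigns to the adversary's chosen labels, and the adversary picks its update to maximize that count. Here f(x; w) denotes the label predicted by a model with weights w and δ_m^t is the malicious agent's update at round t; both symbols are our additions, and w_G^t depends on δ_m^t because the server aggregates it into the global weights.

```latex
\max_{\boldsymbol{\delta}_m^t} \; \mathcal{A}\left(\mathcal{D}_m \cup \mathcal{D}_{aux}, \mathbf{w}_G^t\right),
\qquad
\mathcal{A}\left(\mathcal{D}_m \cup \mathcal{D}_{aux}, \mathbf{w}_G^t\right)
  = \sum_{i=1}^{r} \mathbb{1}\left[ f(\mathbf{x}_i; \mathbf{w}_G^t) = \tau_i \right]
```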

where D_m is the malicious agent's local (original) data and D_aux is the auxiliary dataset with which the adversary wants to contaminate the model. The auxiliary dataset held by the adversary consists of r examples (x_i, tau_i), where tau_i is the label the adversary wants the model to assign to x_i. w_G^t denotes the global model weights at the server at round t. The equation conveys that the adversary optimizes over D_aux only. [2]

To make the setup realistic, the authors equip the server with the following mechanisms for flagging malicious activity:

  1. Accuracy validation: the server uses a validation set to check whether the model performs poorly after applying an update from any agent.
  2. Validation by weight update statistics: the server monitors the distribution of weight updates to detect anomalous behavior among clients/agents.
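The two checks can be sketched as simple predicates. The thresholds (`max_drop`, `slack`) and the particular statistic (an update's largest distance to the other updates compared with their own spread) are hypothetical simplifications for illustration, not the exact tests used in [2].

```python
import math
from typing import List, Sequence


def distance(u: Sequence[float], v: Sequence[float]) -> float:
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(u, v)))


def passes_accuracy_check(acc_with_update: float, acc_without_update: float,
                          max_drop: float = 0.05) -> bool:
    """Check 1: reject the update if it lowers validation accuracy by more than max_drop."""
    return acc_without_update - acc_with_update <= max_drop


def passes_weight_stats_check(update: Sequence[float], others: List[Sequence[float]],
                              slack: float = 1.5) -> bool:
    """Check 2 (one simple statistic): the update's largest distance to the other updates must
    not exceed `slack` times the largest pairwise distance among the others (needs >= 2 others)."""
    own = max(distance(update, o) for o in others)
    spread = max(distance(u, v) for i, u in enumerate(others) for v in others[i + 1:])
    return own <= slack * spread


benign = [[0.10, 0.10], [0.12, 0.09], [0.09, 0.11]]
print(passes_weight_stats_check([0.11, 0.10], benign))   # True: looks like the others
print(passes_weight_stats_check([2.00, -1.50], benign))  # False: far outside the benign spread
```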

These basic checks make the adversary's job harder. The authors therefore develop ways for the adversary to make malicious updates despite the checks at the server.

Analysis at the server for targeted model poisoning

Let’s see what the flagging mechanisms observe when only the Adversarial Objective is optimized by the adversary.

(Left) Indicates the status of Validation Accuracy Test. (Right) Indicates weight statistics of Benign and Malicious (by adversary) updates

From the left graph, we observe that the adversary's (malicious) model fails miserably on the validation set (blue line). Additionally, the weight statistics of the malicious updates (right graph) differ significantly from those of the benign updates, so the red weight updates are clearly distinguishable as anomalous.

One way to improve the validation accuracy is to also train on examples from the same distribution as the validation set (that is, the distribution on which everyone is training). This leads to the authors' next approach.

Stealthy model poisoning (Approach 2: Improvisation)

To stay stealthy under both flagging mechanisms, the new objective must account for both the validation test and the weight update statistics test. The previous objective is therefore modified as follows.


Adversarial Objective with Stealth Objective: delta_m is the update from the malicious agent and delta_ben is the average update from the benign agents
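A sketch of this objective in the notation of this section, following our reading of [2]: L is a differentiable training loss (a cross-entropy, say) standing in for the indicator count of the targeted objective, ŵ_G^t is the adversary's estimate of the global weights after its update, w_m^t is the malicious agent's local model after applying δ_m^t, and λ and ρ are weighting hyperparameters. These symbol names are our labels; the braces mark which terms the green and red colours refer to.

```latex
\min_{\boldsymbol{\delta}_m^t} \;
  \underbrace{\lambda \, L\left(\{\mathbf{x}_i, \tau_i\}_{i=1}^{r}, \hat{\mathbf{w}}_G^t\right)}_{\text{green: target}}
  + \underbrace{L\left(\mathcal{D}_m, \mathbf{w}_m^t\right)
  + \rho \left\lVert \boldsymbol{\delta}_m^t - \bar{\boldsymbol{\delta}}_{ben}^t \right\rVert_2}_{\text{red: stealth}}
```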

Note that the green term is the same as in the targeted model poisoning approach (optimization over the auxiliary dataset on which the adversary wants the global model to predict malicious labels), while the additional red terms help the update avoid the flagging mechanisms.

The first red term optimizes the loss over examples from the same distribution as the validation set (that is, the distribution on which every other learner is training). The second red term keeps the malicious update close to the benign updates, so that the malicious agent's updates can pass the weight update statistics test.

Analysis at the server for stealthy model poisoning

(Left) Indicates the status of Validation Accuracy Test. (Right) Indicates weight statistics of Benign and Malicious (by adversary) updates

The left graph above shows that stealthy model poisoning performs well on the validation set, but it fails the weight statistics test, as the two distributions in the right graph differ significantly.

The authors find that introducing the stealth objective in this way is not enough to pass the weight update statistics test, so they propose an "alternating minimization" approach.

Alternating minimization model poisoning (Approach 3)

In this approach, proposed by the authors, the malicious agent achieves much better stealth: it goes undetected at the server while continuing to optimize over its target set.

They propose to split the optimization of the Adversarial Objective with Stealth Objective into two steps, corresponding to the green (target) and red (stealth) terms respectively.
Step 1: Run p iterations of optimization on the target objective (the green term, i.e., the loss on the auxiliary data D_aux).

Step 2: Run q iterations of optimization on the stealth objective (the red terms, i.e., the loss on data from the benign distribution plus the distance penalty to the benign updates).

These two optimization steps are run alternately, one after the other. At the beginning of the experiment, p and q can both be set to 1. As the experiment progresses, if stealthiness is stable but the target objective is lagging, p can be increased; similarly, if stealthiness is suffering, q can be increased relative to p.
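The alternation can be sketched on a toy problem in which both objectives are hypothetical quadratics over a 2-D update vector: the target term pulls the update toward `W_TARGET`, and the stealth term pulls it toward an estimated mean benign update `W_BENIGN`. The thresholds and learning rate are illustrative; p grows when stealth holds but the target lags, and q grows when stealth suffers.

```python
from typing import Callable, List, Tuple

Vec = List[float]

# Hypothetical stand-ins for the two objectives on a 2-D "update" vector:
W_TARGET = [1.5, 0.0]  # where the target (green) objective alone would push the update
W_BENIGN = [0.0, 0.0]  # estimated mean benign update the stealth (red) term stays close to


def target_loss(w: Vec) -> float:
    return sum((a - b) ** 2 for a, b in zip(w, W_TARGET))


def stealth_loss(w: Vec) -> float:
    return sum((a - b) ** 2 for a, b in zip(w, W_BENIGN))


def target_grad(w: Vec) -> Vec:
    return [2 * (a - b) for a, b in zip(w, W_TARGET)]


def stealth_grad(w: Vec) -> Vec:
    return [2 * (a - b) for a, b in zip(w, W_BENIGN)]


def gd_steps(w: Vec, grad: Callable[[Vec], Vec], steps: int, lr: float) -> Vec:
    for _ in range(steps):
        w = [wi - lr * gi for wi, gi in zip(w, grad(w))]
    return w


def alternating_minimization(w: Vec, rounds: int = 20, lr: float = 0.1,
                             stealth_tol: float = 1.0,
                             target_tol: float = 0.5) -> Tuple[Vec, int, int]:
    p, q = 1, 1                               # start with one iteration each
    for _ in range(rounds):
        w = gd_steps(w, target_grad, p, lr)   # Step 1: p iterations on the target objective
        w = gd_steps(w, stealth_grad, q, lr)  # Step 2: q iterations on the stealth objective
        if stealth_loss(w) > stealth_tol:
            q += 1                            # stealth suffering: more stealth iterations
        elif target_loss(w) > target_tol:
            p += 1                            # stealthy but target lagging: more target iterations
    return w, p, q


w, p, q = alternating_minimization([0.0, 0.0])
```

In this run the schedule settles at p = 3 and q = 1, and the update ends just inside the stealth threshold while meeting the target threshold.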

This gives the optimization finer control, achieving much better performance on the target poisoning while maintaining stealthiness, i.e., not getting caught by either flagging mechanism.

Now, let’s see what the flagging mechanisms observe when an alternating minimization strategy is employed.

(Left) Indicates the status of Validation Accuracy Test. (Right) Indicates weight statistics of Benign and Malicious (by adversary) updates

From the left graph above, we see that the adversary's model now achieves validation accuracy (blue line) close to the global average validation accuracy (green line) while achieving its poisoning objective (brown line) with full confidence. This indicates that the adversary passes the validation accuracy check.

In addition, the adversary's weight update statistics (red) in the right graph are much closer to those of the benign agents. Hence, with this approach, the poisoned updates pass both tests.

Conclusion

From the first paper, "Split learning for health: Distributed deep learning without sharing raw patient data" by Vepakomma et al. [1], we explored new techniques for training models in a federated setting while keeping the data private to the entities that hold it.

The second paper, "Analyzing Federated Learning through an Adversarial Lens" by Bhagoji et al. [2], introduces model poisoning attacks, which become a risk as soon as we adopt the federated learning approach. As Federated Learning gains acceptance as a way to overcome privacy barriers, model poisoning is a new open problem we must now face.

References

[1]: Split learning for health: Distributed deep learning without sharing raw patient data, Vepakomma et al.

[2]: Analyzing Federated Learning through an Adversarial Lens, Bhagoji et al.

[3]: Communication-Efficient Learning of Deep Networks from Decentralized Data, McMahan et al.

[4]: Scaling Distributed Machine Learning with the Parameter Server, Li et al.

CS Grad Student at Columbia University, specializing in Machine Learning