Federated Learning in a Nutshell
In the artificial intelligence field, data is the cornerstone of training machine learning models. Traditionally, vast amounts of data must be uploaded to centralised servers, where it is processed and used to train and fine-tune future AI models. Since this data might be sensitive, ensuring its privacy is imperative. Moreover, with ever stricter laws and regulations on data privacy, this centralised approach, which can potentially leak data, is becoming increasingly hard to implement.
Let’s take an example from a mundane activity: while you are driving your car, the navigation system tries to recommend destinations based on your previous history. This is a clear example of how AI can help us concentrate more on driving than on setting up the navigation system. However, to train such a machine learning model, you (and everyone else) would need to upload all of your GPS coordinates to a centralised server where they can be processed. You might argue that this data is not personally identifiable, but if it were leaked, it could pose a serious threat to your safety.
This is where federated learning comes into the picture, with its novel collaborative and decentralised approach. Federated learning (FL) is a machine learning technique in which multiple entities collaborate in training an AI model under the orchestration of a central server. As opposed to traditional machine learning, in which all of the processing is done centrally, in FL many clients, ranging from mobile phones to entire organisations, are in charge of training the AI model locally. This means that data never has to leave the device; instead, the AI model is distributed to all of the devices, where it is trained, as illustrated in Figure 15. Thereafter, all of the locally trained models are uploaded back to the central orchestrator, leaving no trace of the data they were trained on. Finally, the orchestrator aggregates all of the local models into a global AI model, based on each client's contribution.
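The aggregation step can be sketched in a few lines. The snippet below is a minimal illustration of federated averaging (FedAvg), the most common aggregation rule, in which each client's model parameters are weighted by the size of its local dataset; the function and variable names are ours, not from any particular FL framework.

```python
import numpy as np

def aggregate(client_weights, client_sizes):
    """Federated averaging: combine locally trained model weights into a
    global model, weighting each client by its local dataset size."""
    coeffs = np.array(client_sizes, dtype=float) / sum(client_sizes)
    stacked = np.stack(client_weights)
    # Weighted sum over clients: only model parameters reach the
    # orchestrator, never the raw training data.
    return np.tensordot(coeffs, stacked, axes=1)

# Three clients trained locally; toy 1-D weight vectors:
local_models = [np.array([1.0, 2.0]),
                np.array([3.0, 4.0]),
                np.array([5.0, 6.0])]
samples = [10, 30, 60]  # number of local training examples per client
global_model = aggregate(local_models, samples)
# Client 3 holds 60% of the data, so the global model sits closest to it.
```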
Most often, clients in federated learning are edge servers: servers running at the edge of the network, closer to smartphones than any centralised server on the Internet. The closer a server is to a mobile device, the safer the data, since it travels through fewer hops along the way. An edge server is responsible for storing and processing the data of nearby mobile devices. The farther away a server is, the more data it holds, and the more data a server holds, the more users can be affected by a data breach. Conversely, the more widely data is distributed across a network, the harder it is for an attacker to track it all down.
Generally, federated learning is orchestrated by a centralised server which runs a given number of iterations, or rounds, of interaction with the clients responsible for training the model on their data. The orchestrator is also in charge of aggregating all of the local models into a global model. Each round in the learning process is composed of the following steps:
Initialisation: the orchestrator chooses the appropriate machine learning model that will be trained by the clients. Furthermore, the orchestrator distributes the initial global model to all clients.
Selection: a subset of the clients is selected each round, while the others remain idle for the duration of the current round.
Training: all of the selected clients train their local machine learning models on the data that they own.
Aggregation: upon finishing the training round, all of the selected clients send their local models to the orchestrator, which aggregates them into an updated global model. The orchestrator then sends the updated model to all clients.
This process is executed iteratively, round after round, by the orchestrator until success criteria are met and the global machine learning model is considered finalised. Depending on how orchestration is performed, federated learning can be centralised, when a single node is in charge of orchestration, or fully decentralised, when the clients rely on gossiping and consensus protocols to orchestrate the process.
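The round structure described above can be simulated end to end in a short script. The sketch below is an illustrative toy, not a production implementation: each "client" fits a single scalar parameter to its local data, and all names (clients, hyperparameters) are assumptions made for the example.

```python
import random
import numpy as np

def run_federated_rounds(client_data, num_rounds=5, clients_per_round=2,
                         lr=0.5, seed=0):
    """Toy simulation of the four FL steps: initialisation, selection,
    training, and aggregation. Each client locally fits a scalar model w
    by gradient descent on the loss (w - mean(local data))^2."""
    rng = random.Random(seed)
    global_w = 0.0  # Initialisation: orchestrator picks the starting model
    for _ in range(num_rounds):
        # Selection: only a subset of clients participates this round
        selected = rng.sample(list(client_data), clients_per_round)
        updates, sizes = [], []
        for cid in selected:
            data = client_data[cid]
            w = global_w
            # Training: local gradient steps on data that never leaves the client
            for _ in range(10):
                w -= lr * 2 * (w - np.mean(data))
            updates.append(w)
            sizes.append(len(data))
        # Aggregation: weighted average of local models becomes the new global model
        global_w = float(np.average(updates, weights=sizes))
    return global_w

# Hypothetical clients, each holding its own private data:
clients = {"phone_a": [1.0, 2.0], "phone_b": [3.0], "phone_c": [4.0, 5.0]}
final_model = run_federated_rounds(clients)
```

Note that the orchestrator only ever sees the scalar updates, not the lists of data points; that separation is the core privacy property of federated learning.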
With a clear focus on ensuring user privacy, federated learning has major implications for various industries, with an emphasis on healthcare. The ability to predict clinical outcomes for patients without sacrificing their data ownership can pave the way for much needed breakthroughs. More companies are doing in-depth research into federated learning, leading to a predicted global market growth to $210 million by 2028, from $127 million in 2023, with a CAGR of 10.6%.
Ideally, clients in federated learning should be as close to user data as possible. And the closest we can get is directly on smartphones, which are the source of user-generated data and which have the technical capabilities required to train an AI model. Google has been employing federated learning for quite some time now for text selection and keyboard prediction, and also provides means for Android users to configure their federated learning settings. However, as you can imagine, other companies (that are not invested in building the world’s most popular mobile operating system) do not have the means to reach users at the same scale or with the same level of trust.