Create a custom federated data set in TensorFlow Federated
At a very high level, using an arbitrary dataset with TFF requires the following steps:
- Partition the dataset into per-client subsets (how to do so is a much larger question)
- Create a tf.data.Dataset for each client subset
- Pass a list of all (or a subset of) the Dataset objects to the federated optimization.
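The three steps above can be sketched roughly as follows. This is a minimal illustration, not the tutorial's code: the toy arrays, the client_ids tagging, the helper names partition_by_client and make_client_dataset, and the x/y element structure are all assumptions made for the example. Only tf.data.Dataset.from_tensor_slices and .batch are real TensorFlow calls.

```python
import collections

import numpy as np
import tensorflow as tf

# Hypothetical toy data: 6 examples, each tagged with the client that owns it.
features = np.arange(12, dtype=np.float32).reshape(6, 2)
labels = np.array([0, 1, 0, 1, 1, 0], dtype=np.int32)
client_ids = ["a", "b", "a", "c", "b", "a"]


def partition_by_client(ids):
    """Step 1: group example indices by the client that owns them."""
    indices = collections.defaultdict(list)
    for i, cid in enumerate(ids):
        indices[cid].append(i)
    return dict(indices)


def make_client_dataset(idx, batch_size=2):
    """Step 2: build a batched tf.data.Dataset from one client's examples."""
    elements = collections.OrderedDict(x=features[idx], y=labels[idx])
    return tf.data.Dataset.from_tensor_slices(elements).batch(batch_size)


client_indices = partition_by_client(client_ids)
client_datasets = {
    cid: make_client_dataset(idx) for cid, idx in client_indices.items()
}

# Step 3: a plain Python list of datasets, one per participating client.
federated_train_data = [client_datasets[cid] for cid in sorted(client_datasets)]
```

The element structure (an OrderedDict with x and y keys) is only one possible choice; it has to match whatever input spec the model passed to TFF expects.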
What is happening in the tutorial
The Federated Learning for Image Classification tutorial uses tff.learning.build_federated_averaging_process to build up a federated optimization using the FedAvg algorithm.
In that notebook, the following code executes one round of federated optimization, where the client datasets are passed to the process's .next method:
state, metrics = iterative_process.next(state, federated_train_data)
Here federated_train_data is a Python list of tf.data.Dataset objects, one per client participating in the round.
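Because the list holds only the clients participating in a given round, a common pattern is to sample a fresh subset of clients each round. The sketch below shows one way to do that in plain Python. The helper sample_round_data and the client_datasets mapping are hypothetical, and strings stand in for real tf.data.Dataset objects so the snippet runs without TensorFlow.

```python
import random


def sample_round_data(client_datasets, clients_per_round, seed=None):
    """Pick clients without replacement and return their datasets as a list."""
    rng = random.Random(seed)
    # Sort the ids first so a given seed always selects the same clients.
    chosen = rng.sample(sorted(client_datasets), clients_per_round)
    return [client_datasets[cid] for cid in chosen]


# Stand-in mapping of client id -> dataset (strings here; in practice the
# values would be tf.data.Dataset objects).
client_datasets = {f"client_{i}": f"dataset_{i}" for i in range(10)}

federated_train_data = sample_round_data(
    client_datasets, clients_per_round=3, seed=0
)
```

The resulting list can be passed to iterative_process.next in place of a fixed federated_train_data, calling sample_round_data again (with a different seed or none) for each round.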
The ClientData object
The canned datasets provided by TFF (under tff.simulation.datasets) are implemented using the tff.simulation.ClientData interface, which manages the client → dataset mapping and tf.data.Dataset creation.
If you're planning to re-use a dataset, implementing it as a tff.simulation.ClientData may make future use easier.