EMNIST classification with federated learning: a code walkthrough

Project introduction

This project uses a federated learning algorithm, implemented with TensorFlow Federated (TFF), to classify the EMNIST dataset.

First, one point needs clarifying: the difference between the decorators @tff.tf_computation and @tff.federated_computation:

  • TensorFlow Federated (TFF) operations work on federated values;
  • Each federated value has a federated type, which consists of a member type (e.g. tf.float32) and a placement (e.g. tff.CLIENTS);
  • Federated values can be moved between placements (e.g. clients -> server), but only inside a @tff.federated_computation and with a federated type;
  • The type signature of a function decorated with @tff.federated_computation has two attributes: the type and the placement mentioned above;
  • The type signature of a function decorated with @tff.tf_computation has only the type attribute;
  • TensorFlow operations must be implemented inside a @tff.tf_computation, which is then embedded in a @tff.federated_computation.

For example:

@tff.tf_computation(tf.float32)
def add_half(x):
    return tf.add(x, 0.5)

str(add_half.type_signature)
'(float32 -> float32)'

@tff.federated_computation(tff.FederatedType(tf.float32, tff.CLIENTS))
def get_average_temperature(client_temperatures):
    return tff.federated_mean(client_temperatures)

str(get_average_temperature.type_signature)
'({float32}@CLIENTS -> float32@SERVER)'

This indicates that the input of get_average_temperature() is {float32}@CLIENTS and the output is float32@SERVER. The part after the @ is the placement.

In addition, @tff.tf_computation only needs a plain data type as its argument, whereas @tff.federated_computation requires a federated type, i.e. a type plus a placement.
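
To make the signature notation concrete, here is a toy sketch in plain Python. It is not TFF's implementation: ToyFederatedType and signature are hypothetical names used only to show how a member type, a placement, and the braces for per-client values combine into the strings shown above.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyFederatedType:
    """Hypothetical stand-in for a federated type: member type plus placement."""
    member: str       # e.g. "float32"
    placement: str    # "CLIENTS" or "SERVER"
    all_equal: bool   # True when every participant holds the same value

    def __str__(self) -> str:
        # Values that may differ from client to client are written in braces.
        body = self.member if self.all_equal else "{" + self.member + "}"
        return f"{body}@{self.placement}"


def signature(parameter: ToyFederatedType, result: ToyFederatedType) -> str:
    """Render a function type in the same style as the strings above."""
    return f"({parameter} -> {result})"


client_temperatures = ToyFederatedType("float32", "CLIENTS", all_equal=False)
server_average = ToyFederatedType("float32", "SERVER", all_equal=True)
mean_signature = signature(client_temperatures, server_average)
print(mean_signature)  # ({float32}@CLIENTS -> float32@SERVER)
```

The braces disappear on the server side because a single averaged value lives there, while each client may hold a different temperature.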

Code description

1. Import required libraries

import collections
import attr
import functools
import numpy as np
import tensorflow as tf
import tensorflow_federated as tff


import nest_asyncio
nest_asyncio.apply()  # lets TFF run inside a notebook's already-running event loop

2. Import dataset

The EMNIST dataset can be downloaded directly through the tensorflow_federated library.

emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()

3. Processing data sets

Two things need to be done:
Step 1: convert the data of the clients selected from emnist_train into a tf.data.Dataset format that can be used for training;
Step 2: reshape the samples and labels into the shapes the model expects as input.

NUM_CLIENTS = 10  # number of clients to sample
BATCH_SIZE = 20  # batch size

# Preprocessing function that batches the data and reshapes each batch
def preprocess(dataset):

    def batch_format_fn(element):
        """Flatten a batch of EMNIST data and return a (features, label) tuple."""
        return (tf.reshape(element['pixels'], [-1, 784]), 
                tf.reshape(element['label'], [-1, 1]))

    return dataset.batch(BATCH_SIZE).map(batch_format_fn)

# Randomly select 10 clients from the original data
client_ids = np.random.choice(emnist_train.client_ids, size=NUM_CLIENTS, replace=False)

federated_train_data = [preprocess(emnist_train.create_tf_dataset_for_client(x))
  for x in client_ids
]  # Preprocess these selected data
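
The effect of batching and reshaping can be seen without TensorFlow. The sketch below is plain Python under stated assumptions: flatten_image, batch, and the client with 45 blank images are made up for illustration. It mimics batching by 20 and flattening each 28x28 image to 784 values.

```python
def flatten_image(image):
    """Turn a 28x28 nested list into a flat list of 784 pixel values."""
    return [pixel for row in image for pixel in row]


def batch(items, batch_size):
    """Group consecutive items into lists of at most batch_size elements."""
    return [items[i:i + batch_size] for i in range(0, len(items), batch_size)]


# Hypothetical client holding 45 blank 28x28 images, all labelled 3.
examples = [{"pixels": [[0.0] * 28 for _ in range(28)], "label": 3}
            for _ in range(45)]

# Each batch becomes a (features, labels) pair, like batch_format_fn returns.
batches = [
    ([flatten_image(e["pixels"]) for e in group], [[e["label"]] for e in group])
    for group in batch(examples, 20)
]

batch_sizes = [len(features) for features, _ in batches]
print(batch_sizes)            # [20, 20, 5]
print(len(batches[0][0][0]))  # 784
print(batches[0][1][0])       # [3]
```

The last batch is smaller than the others, which is why the batch dimension is left unspecified (-1) in the reshape.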

4. Create model

Again there are two steps:
Step 1: create the model with tf.keras;
Step 2: use tff.learning.from_keras_model to convert the Keras model into a model that federated learning can use.

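# The listing is truncated in the source; lines marked "assumed" are reconstructed.
def create_keras_model():
    return tf.keras.models.Sequential([
        tf.keras.layers.Input(shape=(784,)),  # assumed from float32[784,10]
        tf.keras.layers.Dense(10, kernel_initializer='zeros'),
        tf.keras.layers.Softmax()])  # assumed

def model_fn():
    keras_model = create_keras_model()
    return tff.learning.from_keras_model(
        keras_model,
        input_spec=federated_train_data[0].element_spec,  # assumed
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),  # assumed
        metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])  # assumed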

5. Define the model initialization function

Here we define initialize_fn(), which initializes the model parameters on the server.

# initialize_fn() function
@tff.tf_computation
def server_init():
    model = model_fn()
    return model.trainable_variables

@tff.federated_computation
def initialize_fn():
    # Return a federated value: model.trainable_variables placed on the server
    return tff.federated_value(server_init(), tff.SERVER)

6. Define the iteration function

Next we define next_fn(), which updates the server model parameters once per training round. The process consists of four steps:
Step 1: broadcast the server model weights to the clients;
Step 2: each client trains on its own data and updates its weights;
Step 3: compute the average of all client weights;
Step 4: assign that average to the server model as its new weights.
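
The four steps can be traced in plain Python before the TFF version. This is a minimal sketch, not the article's implementation: it assumes a toy one-parameter model y = w * x with squared-error loss, and client_update, next_round and the two tiny client datasets are illustrative only.

```python
def client_update(server_weight, dataset, learning_rate=0.1):
    """Step 2: start from the server weight, run one SGD pass over local data."""
    w = server_weight
    for x, y in dataset:
        grad = 2.0 * (w * x - y) * x  # derivative of (w*x - y)**2 w.r.t. w
        w -= learning_rate * grad
    return w


def next_round(server_weight, client_datasets):
    """One round of federated averaging over all clients."""
    # Step 1: broadcast the server weight to every client.
    broadcast = [server_weight for _ in client_datasets]
    # Step 2: every client trains locally, starting from the broadcast weight.
    client_weights = [client_update(w, data)
                      for w, data in zip(broadcast, client_datasets)]
    # Step 3: unweighted mean of the client weights.
    mean_weight = sum(client_weights) / len(client_weights)
    # Step 4: the mean becomes the new server weight.
    return mean_weight


# Two hypothetical clients whose data follow y = 2 * x.
client_datasets = [[(1.0, 2.0)], [(1.0, 2.0), (2.0, 4.0)]]

first_round = next_round(0.0, client_datasets)  # roughly 1.04
weight = 0.0
for _ in range(20):
    weight = next_round(weight, client_datasets)
print(first_round, weight)  # the weight approaches 2.0
```

Every client counts equally in the average here, regardless of how many examples it holds.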

6.1 Client weight-update function

@tf.function
def client_update(model, dataset, server_weights, client_optimizer):
    """Performs training (using the server model weights) on the client's dataset."""
    # Initialize client model with server model weight
    client_weights = model.trainable_variables
    tf.nest.map_structure(lambda x, y: x.assign(y),
                          client_weights, server_weights)

    # Update the client model with its local data
    for batch in dataset:
        with tf.GradientTape() as tape:
            outputs = model.forward_pass(batch)  # Forward pass

        # Compute gradients
        grads = tape.gradient(outputs.loss, client_weights)
        grads_and_vars = zip(grads, client_weights)
        client_optimizer.apply_gradients(grads_and_vars)  # Update weights

    return client_weights

6.2 Server weight-update function

@tf.function
def server_update(model, mean_client_weights):
    """Updates the server model weights as the average of the client model weights."""
    model_weights = model.trainable_variables
    # Assign the average value of the client model to the server model
    tf.nest.map_structure(lambda x, y: x.assign(y),
                          model_weights, mean_client_weights)
    return model_weights

6.3 Wrap the two update functions with @tff.tf_computation

First obtain the dataset type and the model weight type:

# The decorator needs the argument types; the dataset type can be read from the model's input spec
whimsy_model = model_fn()
tf_dataset_type = tff.SequenceType(whimsy_model.input_spec)

# Extract the model weight type directly from the server_init function defined earlier
model_weights_type = server_init.type_signature.result

# Wrap the client update function with @tff.tf_computation
@tff.tf_computation(tf_dataset_type, model_weights_type)
def client_update_fn(tf_dataset, server_weights):
    model = model_fn()
    client_optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)
    return client_update(model, tf_dataset, server_weights, client_optimizer)

# Wrap the server update function with @tff.tf_computation
@tff.tf_computation(model_weights_type)
def server_update_fn(mean_client_weights):
    model = model_fn()
    return server_update(model, mean_client_weights)

6.4 Wrap next_fn() with @tff.federated_computation

federated_server_type = tff.FederatedType(model_weights_type, tff.SERVER)  # Federated type of the server model parameters
federated_dataset_type = tff.FederatedType(tf_dataset_type, tff.CLIENTS)  # Federated type of the client datasets

@tff.federated_computation(federated_server_type, federated_dataset_type)
def next_fn(server_weights, federated_dataset):
    # Broadcast the server weights to the clients.
    server_weights_at_client = tff.federated_broadcast(server_weights)

    # Each client computes their updated weights.
    client_weights = tff.federated_map(
      client_update_fn, (federated_dataset, server_weights_at_client))

    # The server averages these updates.
    mean_client_weights = tff.federated_mean(client_weights)

    # The server updates its model.
    server_weights = tff.federated_map(server_update_fn, mean_client_weights)

    return server_weights
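
The averaging step in next_fn operates on a structure of tensors (a kernel and a bias) rather than on a single number. Here is a plain-Python sketch, not TFF's implementation, of an unweighted element-wise mean over such a structure. The name elementwise_mean and the tiny 2x2 weights are assumptions made for illustration.

```python
def elementwise_mean(values):
    """Average a list of identically shaped nested lists (or plain numbers)."""
    if isinstance(values[0], list):
        # Recurse into matching sub-structures from every client.
        return [elementwise_mean(list(group)) for group in zip(*values)]
    return sum(values) / len(values)


# Each hypothetical client holds [kernel (2x2), bias (2)].
client_a = [[[1.0, 2.0], [3.0, 4.0]], [0.0, 1.0]]
client_b = [[[3.0, 4.0], [5.0, 6.0]], [2.0, 3.0]]

mean_weights = elementwise_mean([client_a, client_b])
print(mean_weights)  # [[[2.0, 3.0], [4.0, 5.0]], [1.0, 2.0]]
```

The result has the same structure as each client's weights, which is what allows the server update to assign it leaf by leaf.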

7. Encapsulate the algorithm in tff.templates.IterativeProcess

federated_algorithm = tff.templates.IterativeProcess(
    initialize_fn=initialize_fn,
    next_fn=next_fn
)

Here are the type signatures of the two functions we passed in:

str(federated_algorithm.initialize.type_signature)
'( -> <float32[784,10],float32[10]>@SERVER)'

This shows that federated_algorithm.initialize takes no arguments and returns the weights of a single-layer model (a 784x10 kernel and a 10-element bias) placed on the server.

str(federated_algorithm.next.type_signature)
'(<<float32[784,10],float32[10]>@SERVER,{<float32[?,784],int32[?,1]>*}@CLIENTS> -> <float32[784,10],float32[10]>@SERVER)'

This shows that federated_algorithm.next takes the server model weights and the client datasets as input and returns the updated server model weights.

8. Processing test set data

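The test set is preprocessed in the same way as the training set, except that the data of all clients is merged into one centralized dataset and only the first 1000 examples are kept.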

central_emnist_test = emnist_test.create_tf_dataset_from_all_clients().take(1000)
central_emnist_test = preprocess(central_emnist_test)

9. Test set accuracy

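9.1 Evaluation function

def evaluate(server_state):
    keras_model = create_keras_model()
    # The source listing stops here; the lines below are assumed from the log output.
    keras_model.compile(
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),
        metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
    keras_model.set_weights(server_state)
    keras_model.evaluate(central_emnist_test)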

9.2 Model initialization

server_state = federated_algorithm.initialize()

# Evaluate before any training
evaluate(server_state)
50/50 [==============================] - 1s 17ms/step - loss: 2.3026 - sparse_categorical_accuracy: 0.1130
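
The untrained loss of 2.3026 is what a 10-class classifier produces when it predicts a uniform distribution, namely ln(10). The check below assumes the model ends in a softmax and is scored with sparse categorical cross-entropy, as the metric names in the log suggest. It also shows why the progress bar reads 50/50.

```python
import math

NUM_CLASSES = 10

# Zero-initialized weights give a logit of 0 for every class.
logits = [0.0] * NUM_CLASSES

# Softmax of equal logits is the uniform distribution.
exps = [math.exp(z) for z in logits]
probs = [e / sum(exps) for e in exps]

# Cross-entropy for one example; the true label does not matter here.
true_label = 7
loss = -math.log(probs[true_label])
print(round(loss, 4))  # 2.3026

# 1000 test examples in batches of 20 give the 50 steps shown in the log.
num_batches = math.ceil(1000 / 20)
print(num_batches)  # 50
```

An accuracy near 0.1 is likewise what an untrained 10-class model would be expected to score if the classes were roughly balanced.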

9.3 Model training

# Train for 15 rounds
for round_num in range(15):
    server_state = federated_algorithm.next(server_state, federated_train_data)

evaluate(server_state)
50/50 [==============================] - 1s 15ms/step - loss: 2.1630 - sparse_categorical_accuracy: 0.2570

