Project introduction
A federated learning algorithm is used to classify the EMNIST dataset.
First of all, we should clarify one point: the difference between the decorators `@tff.tf_computation` and `@tff.federated_computation`:
- `tensorflow_federated` (TFF) operations work on federated values;
- every federated value has a federated type, which consists of a type (e.g. `tf.float32`) and a placement (e.g. `tff.CLIENTS`);
- federated values can be moved (client -> server); this must be done through an `@tff.federated_computation` together with a federated type;
- the type signature of a function decorated with `@tff.federated_computation` has both attributes mentioned above: type and placement;
- the type signature of a function decorated with `@tff.tf_computation` has only the type attribute;
- TF operations must be implemented inside an `@tff.tf_computation` and then composed into an `@tff.federated_computation`.
For example:
```python
@tff.tf_computation(tf.float32)
def add_half(x):
  return tf.add(x, 0.5)

str(add_half.type_signature)
```
'(float32 -> float32)'
```python
@tff.federated_computation(tff.FederatedType(tf.float32, tff.CLIENTS))
def get_average_temperature(client_temperatures):
  return tff.federated_mean(client_temperatures)

str(get_average_temperature.type_signature)
```
'({float32}@CLIENTS -> float32@SERVER)'
This shows that the input of `get_average_temperature()` is `{float32}@CLIENTS` and the output is `float32@SERVER`; the `@CLIENTS` / `@SERVER` parts are called placements.
In addition, `@tff.tf_computation` only needs to be given the data type (type), while `@tff.federated_computation` must be given a federated type, i.e. both type and placement.
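To make this concrete, a federated type can also be constructed and inspected directly. The following is a minimal sketch (assuming the imports from section 1 below), purely as an illustration:

```python
# A federated type pairs a plain TF type with a placement.
# Constructing one by hand, for illustration only:
server_float = tff.FederatedType(tf.float32, tff.SERVER)
print(str(server_float))  # expected: 'float32@SERVER'
```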
Code description
1. Import required libraries
```python
import collections
import attr
import functools
import numpy as np
import tensorflow as tf
import tensorflow_federated as tff

np.random.seed(0)

import nest_asyncio
nest_asyncio.apply()
```
2. Import dataset
The EMNIST dataset can be downloaded directly from the `tensorflow_federated` library.
```python
emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()
```
3. Processing data sets
There are two things we need to do:
Step 1: convert the client data we need from the `emnist_train` dataset into a `tf.data.Dataset` format that can be used for training;
Step 2: reshape the samples and labels into the shapes expected by the model.
```python
NUM_CLIENTS = 10  # 10 clients
BATCH_SIZE = 20   # Batch size is 20

# Preprocessing function used to reshape the data
def preprocess(dataset):

  def batch_format_fn(element):
    """Flatten a batch of EMNIST data and return a (features, label) tuple."""
    return (tf.reshape(element['pixels'], [-1, 784]),
            tf.reshape(element['label'], [-1, 1]))

  return dataset.batch(BATCH_SIZE).map(batch_format_fn)
```
```python
# Randomly select 10 clients from the original data
client_ids = np.random.choice(emnist_train.client_ids,
                              size=NUM_CLIENTS, replace=False)

# Preprocess the selected clients' data
federated_train_data = [
    preprocess(emnist_train.create_tf_dataset_for_client(x))
    for x in client_ids
]
```
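As a quick sanity check (a sketch, not part of the original pipeline), we can inspect the element spec of one preprocessed client dataset; after `batch_format_fn` it should describe `(None, 784)` feature batches and `(None, 1)` label batches:

```python
# Inspect the structure produced by preprocess(); this same element_spec
# is what model_fn() will pass to tff.learning.from_keras_model below.
print(federated_train_data[0].element_spec)
```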
4. Create model
There are also two steps to be done:
Step 1: create the model with `tf.keras`;
Step 2: use `tff.learning.from_keras_model` to wrap the newly created Keras model into a model that federated learning can use.
```python
def create_keras_model():
  return tf.keras.models.Sequential([
      tf.keras.layers.Input(shape=(784,)),
      tf.keras.layers.Dense(10, kernel_initializer='zeros'),
      tf.keras.layers.Softmax(),
  ])
```
```python
def model_fn():
  keras_model = create_keras_model()
  return tff.learning.from_keras_model(
      keras_model,
      input_spec=federated_train_data[0].element_spec,
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
```
5. Define initialization model functions
Here we need to define `initialize_fn()` to initialize the model parameters on the server.
```python
# initialize_fn() function
@tff.tf_computation
def server_init():
  model = model_fn()
  return model.trainable_variables

@tff.federated_computation
def initialize_fn():
  # Returns a federated value that places the model's trainable_variables
  # on the server
  return tff.federated_value(server_init(), tff.SERVER)
```
6. Define iteration function
Next we define `next_fn()`, which updates the server's model parameters in each training round. Specifically, this process consists of four steps:
Step 1: broadcast the server model weights to the clients;
Step 2: each client trains and updates its weights separately;
Step 3: calculate the average of all client weights (see the numpy sketch after this list);
Step 4: upload the calculated mean to the server as the new server model weights.
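The averaging in step 3 is simply an element-wise mean over the clients' weight tensors. The numpy sketch below illustrates the arithmetic with two hypothetical clients; in the actual pipeline this is done by `tff.federated_mean`:

```python
import numpy as np

# Two hypothetical clients' versions of the same weight tensor
client_ws = [np.array([1.0, 2.0]), np.array([3.0, 4.0])]

# Element-wise mean across clients, as tff.federated_mean computes
mean_w = np.mean(client_ws, axis=0)
print(mean_w)  # [2. 3.]
```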
6.1 Client weight update function
```python
@tf.function
def client_update(model, dataset, server_weights, client_optimizer):
  """Performs training (using the server model weights) on the client's dataset."""
  # Initialize the client model with the server model weights
  client_weights = model.trainable_variables
  tf.nest.map_structure(lambda x, y: x.assign(y),
                        client_weights, server_weights)

  # Update the client model
  for batch in dataset:
    with tf.GradientTape() as tape:
      outputs = model.forward_pass(batch)  # Forward pass

    # Compute gradients and apply them to update the weights
    grads = tape.gradient(outputs.loss, client_weights)
    grads_and_vars = zip(grads, client_weights)
    client_optimizer.apply_gradients(grads_and_vars)

  return client_weights
```
6.2 Server weight update function
```python
@tf.function
def server_update(model, mean_client_weights):
  """Updates the server model weights as the average of the client model weights."""
  model_weights = model.trainable_variables
  # Assign the mean of the client model weights to the server model
  tf.nest.map_structure(lambda x, y: x.assign(y),
                        model_weights, mean_client_weights)
  return model_weights
```
6.3 Decorate the two update functions above with `tff.tf_computation`
First, we need the dataset type and the model weight type:
```python
# The corresponding types are required; the dataset type can be extracted
# directly from our model's input spec
whimsy_model = model_fn()
tf_dataset_type = tff.SequenceType(whimsy_model.input_spec)
str(tf_dataset_type)

# The previously defined server_init function directly gives us the
# model weight type
model_weights_type = server_init.type_signature.result
str(model_weights_type)
```
```python
# Decorate the client update function with tff.tf_computation
@tff.tf_computation(tf_dataset_type, model_weights_type)
def client_update_fn(tf_dataset, server_weights):
  model = model_fn()
  client_optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)
  return client_update(model, tf_dataset, server_weights, client_optimizer)

# Decorate the server update function with tff.tf_computation
@tff.tf_computation(model_weights_type)
def server_update_fn(mean_client_weights):
  model = model_fn()
  return server_update(model, mean_client_weights)
```
6.4 Decorate the `next_fn()` function with `tff.federated_computation`
```python
# Federated type of the server model parameters
federated_server_type = tff.FederatedType(model_weights_type, tff.SERVER)
# Federated type of the client datasets
federated_dataset_type = tff.FederatedType(tf_dataset_type, tff.CLIENTS)

@tff.federated_computation(federated_server_type, federated_dataset_type)
def next_fn(server_weights, federated_dataset):
  # Broadcast the server weights to the clients.
  server_weights_at_client = tff.federated_broadcast(server_weights)

  # Each client computes their updated weights.
  client_weights = tff.federated_map(
      client_update_fn, (federated_dataset, server_weights_at_client))

  # The server averages these updates.
  mean_client_weights = tff.federated_mean(client_weights)

  # The server updates its model.
  server_weights = tff.federated_map(server_update_fn, mean_client_weights)

  return server_weights
```
7. Encapsulate the algorithm in `tff.templates.IterativeProcess`
```python
federated_algorithm = tff.templates.IterativeProcess(
    initialize_fn=initialize_fn,
    next_fn=next_fn
)
```
Here we can see the type signatures of the two functions we passed in:
```python
str(federated_algorithm.initialize.type_signature)
```
'( -> <float32[784,10],float32[10]>@SERVER)'
The result shows that `federated_algorithm.initialize` is a function with no parameters that returns a single-layer model (a `float32[784,10]` weight matrix and a `float32[10]` bias), placed on the server.
```python
str(federated_algorithm.next.type_signature)
```
'(<<float32[784,10],float32[10]>@SERVER,{<float32[?,784],int32[?,1]>*}@CLIENTS> -> <float32[784,10],float32[10]>@SERVER)'
The result shows that `federated_algorithm.next` takes a server model and client data as input and returns an updated server model.
8. Processing test set data
The method here is the same as for the training set data.
```python
central_emnist_test = emnist_test.create_tf_dataset_from_all_clients().take(1000)
central_emnist_test = preprocess(central_emnist_test)
```
9. Test set accuracy
9.1 test function
```python
def evaluate(server_state):
  keras_model = create_keras_model()
  keras_model.compile(
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()]
  )
  keras_model.set_weights(server_state)
  keras_model.evaluate(central_emnist_test)
```
9.2 model initialization
```python
server_state = federated_algorithm.initialize()

# No training yet; test directly
evaluate(server_state)
```
50/50 [==============================] - 1s 17ms/step - loss: 2.3026 - sparse_categorical_accuracy: 0.1130
9.3 model training
```python
# Train for 15 rounds
for round in range(15):
  server_state = federated_algorithm.next(server_state, federated_train_data)

evaluate(server_state)
```
50/50 [==============================] - 1s 15ms/step - loss: 2.1630 - sparse_categorical_accuracy: 0.2570