
Federated Learning System

Introduction to the Federated Learning Architecture

The federated learning architecture enables decentralized model training across multiple Client Devices (e.g., mobile phones, IoT devices) while preserving data privacy. Local model updates are computed on-device using frameworks like TensorFlow Federated or PySyft, then combined by a Central Server with an aggregation algorithm such as Federated Averaging (FedAvg), optionally wrapped in a secure aggregation protocol so the server sees only the combined result. The system integrates Kafka for streaming updates, a Model Registry for versioning, and a Database for metadata. Security is provided by TLS, differential privacy, and role-based access control (RBAC). Redis caches intermediate results, and Prometheus with Grafana provides observability into training progress and system health.

In federated learning, raw data remains on client devices and only model updates are shared, which improves privacy and enables decentralized training.
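
The training round described above can be simulated in a few lines. This is a minimal sketch, assuming a toy linear-regression task, three hypothetical clients, and plain NumPy in place of TensorFlow Federated or PySyft: each client trains on its private data starting from the current global model, and only its weight vector and example count reach the server, which combines them with example-weighted FedAvg.

```python
import numpy as np

# Toy federated rounds: raw (x, y) never leaves a client; only weights and counts do.

def make_client_data(rng, true_w, n):
    """Hypothetical private dataset for one client: y = x @ true_w + small noise."""
    x = rng.normal(size=(n, true_w.size))
    y = x @ true_w + 0.01 * rng.normal(size=n)
    return x, y

def local_update(global_w, x, y, lr=0.1, epochs=20):
    """Full-batch gradient descent on mean squared error, starting from global_w."""
    w = global_w.copy()
    for _ in range(epochs):
        grad = 2.0 * x.T @ (x @ w - y) / len(y)
        w -= lr * grad
    return w, len(y)

def fed_avg(results):
    """Average client weights, weighting each client by its number of examples."""
    total = sum(n for _, n in results)
    return sum(w * (n / total) for w, n in results)

rng = np.random.default_rng(0)
true_w = np.array([2.0, -1.0, 0.5])
clients = [make_client_data(rng, true_w, n) for n in (50, 120, 80)]

global_w = np.zeros(3)
for round_num in range(10):
    # The server receives only (weights, num_examples) pairs from each client
    results = [local_update(global_w, x, y) for x, y in clients]
    global_w = fed_avg(results)

print(np.round(global_w, 2))
```

After ten rounds the global weights should land close to true_w, even though the server never saw a single (x, y) pair.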

High-Level System Diagram

The diagram illustrates the federated learning pipeline: Client Devices (e.g., mobile, IoT) perform Local Training on user data, generating model updates. These updates are sent via an API Gateway to the Central Server, which uses Secure Aggregation to combine them into a global model. The global model is stored in a Model Registry and distributed back to clients. Kafka streams updates and metadata, while a Database stores training metadata (e.g., client participation logs). Redis caches model updates for efficiency, and Prometheus monitors system health. Arrows are color-coded: yellow (dashed) for client flows, orange-red for server flows, green (dashed) for data/cache flows, blue (dotted) for model update flows, and purple for monitoring.

graph TD
    A[Mobile Client] -->|HTTP Request| B[API Gateway]
    C[IoT Client] -->|HTTP Request| B
    A -->|Local Training| D[Local Training]
    C -->|Local Training| D
    D -->|Model Updates| B
    B -->|Routes| E[Central Server]
    E -->|Aggregates| F[Secure Aggregation]
    F -->|Stores| G[(Model Registry)]
    E -->|Distributes| B
    B -->|Updates| A
    B -->|Updates| C
    E -->|Streams| H[Kafka]
    H -->|Stores| I[(Database)]
    E -->|Cache| J[(Cache)]
    F -->|Cache| J
    E -->|Metrics| K[(Monitoring)]
    B -->|Metrics| K
    subgraph Clients
        A
        C
        D
    end
    subgraph Server
        E
        F
        G
    end
    subgraph Data Pipeline
        H
        I
        J
    end
    subgraph Monitoring
        K
    end
    classDef gateway fill:#ff6f61,stroke:#ff6f61,stroke-width:2px,rx:10,ry:10;
    classDef service fill:#405de6,stroke:#405de6,stroke-width:2px,rx:5,ry:5;
    classDef storage fill:#2ecc71,stroke:#2ecc71,stroke-width:2px;
    classDef monitoring fill:#9b59b6,stroke:#9b59b6,stroke-width:2px;
    class B gateway;
    class E,F,D service;
    class G,I,J storage;
    class K monitoring;
    linkStyle 0,1,4 stroke:#ffeb3b,stroke-width:2.5px,stroke-dasharray:6,6
    linkStyle 2,3,5,6,7,8 stroke:#ff6f61,stroke-width:2.5px
    linkStyle 9,10 stroke:#405de6,stroke-width:2.5px,stroke-dasharray:4,4
    linkStyle 11,12,13,14 stroke:#2ecc71,stroke-width:2.5px,stroke-dasharray:5,5
    linkStyle 15,16 stroke:#9b59b6,stroke-width:2.5px
The Central Server securely aggregates model updates while protecting client data through differential privacy and encrypted communication.
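
The idea behind secure aggregation can be illustrated with pairwise masking: each pair of clients shares a random mask that one client adds and the other subtracts, so every individual masked update looks like noise to the server, yet the masks cancel in the sum. This is a minimal sketch, assuming no client drops out and that the pairwise seeds are already shared; real protocols derive those seeds through key agreement and recover from dropouts with secret sharing, and the seed values here are hypothetical.

```python
import numpy as np

def pairwise_mask(client_idx, num_clients, dim, seeds):
    """Combined +/- mask for one client; seeds[(i, j)] is shared by clients i < j."""
    mask = np.zeros(dim)
    for other in range(num_clients):
        if other == client_idx:
            continue
        i, j = min(client_idx, other), max(client_idx, other)
        pair_mask = np.random.default_rng(seeds[(i, j)]).normal(size=dim)
        # The lower-indexed client adds the shared mask, the higher-indexed one subtracts it
        mask += pair_mask if client_idx == i else -pair_mask
    return mask

num_clients, dim = 4, 5
rng = np.random.default_rng(42)
updates = [rng.normal(size=dim) for _ in range(num_clients)]
# Hypothetical pre-shared seeds; a real protocol would derive these via key agreement
seeds = {(i, j): 1000 * i + j for i in range(num_clients) for j in range(i + 1, num_clients)}

masked = [u + pairwise_mask(k, num_clients, dim, seeds) for k, u in enumerate(updates)]

# The server only sees `masked`, yet their sum equals the sum of the raw updates
server_sum = np.sum(masked, axis=0)
print(np.allclose(server_sum, np.sum(updates, axis=0)))  # True
```

Dividing server_sum by the number of clients (or weighting as in FedAvg) then yields the global update without exposing any single client's contribution.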

Key Components

The core components of the federated learning architecture include:

  • Client Devices (Mobile, IoT): Perform local training on private user data using TensorFlow Federated or PySyft.
  • Local Training: Computes model updates on-device without sharing raw data.
  • API Gateway: Routes model updates and enforces rate limiting (e.g., Kong).
  • Central Server: Orchestrates aggregation and model distribution.
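  • Secure Aggregation: Combines client updates with FedAvg, optionally protected by a secure aggregation protocol (e.g., secure multi-party computation) so the server learns only the aggregate, not individual updates.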
  • Model Registry: Stores global models with versioning (e.g., MLflow).
  • Kafka: Streams model updates and metadata for scalability.
  • Database: Stores training metadata, such as client participation logs (e.g., MongoDB).
  • Cache: Redis for low-latency access to model updates and metadata.
  • Monitoring: Prometheus and Grafana for system and training performance metrics.
  • Security: TLS, differential privacy, and RBAC for secure and private operations.
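
For reference, the FedAvg step named under Secure Aggregation is usually written as an example-weighted average. With K participating clients in round t, where client k holds n_k examples and returns locally trained weights w_{t+1}^k, the server computes the new global weights as:

```latex
w_{t+1} = \sum_{k=1}^{K} \frac{n_k}{n}\, w_{t+1}^{k}, \qquad n = \sum_{k=1}^{K} n_k
```

Clients with more data pull the global model further toward their local solution; when every n_k is equal, this reduces to a plain mean of the client weights.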

Benefits of the Architecture

  • Data Privacy: Local training keeps raw user data on devices; only model updates are transmitted.
  • Scalability: Kafka topic partitioning and distributed aggregation let the system scale to large client populations.
  • Resilience: Decentralized training and caching limit the impact of individual client or service failures.
  • Performance: Caching and optimized aggregation minimize latency.
  • Flexibility: Compatible with various frameworks (TensorFlow, PyTorch) and aggregation protocols.
  • Observability: Comprehensive monitoring of training progress and system health.
  • Security: TLS, differential privacy, and RBAC protect model updates and metadata.
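
Training-progress observability can be wired in with the prometheus_client library. The sketch below assumes the server tracks a counter of received updates (labeled by client type) and gauges for the current round and the global model's loss; the metric names are hypothetical, and a dedicated registry keeps the example self-contained.

```python
from prometheus_client import CollectorRegistry, Counter, Gauge, generate_latest

registry = CollectorRegistry()

# Hypothetical metric names for a federated training server
updates_received = Counter(
    'fl_updates_received', 'Client model updates received', ['client_type'], registry=registry)
current_round = Gauge('fl_current_round', 'Current federated training round', registry=registry)
global_loss = Gauge('fl_global_loss', 'Evaluation loss of the latest global model', registry=registry)

def record_round(round_num, client_types, loss):
    """Update metrics after a round has been aggregated and evaluated."""
    for client_type in client_types:
        updates_received.labels(client_type=client_type).inc()
    current_round.set(round_num)
    global_loss.set(loss)

record_round(1, ['mobile', 'mobile', 'iot'], 0.42)

# A long-running server would call start_http_server(8000, registry=registry) so that
# Prometheus can scrape these; here we just print the text exposition format.
print(generate_latest(registry).decode('utf-8'))
```

Grafana can then chart fl_global_loss against fl_current_round to show loss convergence.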

Implementation Considerations

Building a robust federated learning system requires meticulous planning to ensure privacy, scalability, and efficiency:

  • Client-Side Training: Optimize TensorFlow Federated or PySyft for resource-constrained devices (e.g., mobile GPUs).
  • API Gateway: Configure Kong with JWT validation and rate limiting to handle client traffic.
  • Secure Aggregation: Implement FedAvg with differential privacy or secure multi-party computation for privacy-preserving aggregation.
  • Model Registry: Use MLflow to version global models and track training metadata.
  • Kafka Configuration: Set up topic partitioning for scalable streaming of model updates.
  • Database: Use MongoDB with encrypted connections and indexed queries for metadata storage.
  • Cache Strategy: Configure Redis with TTLs for model updates and client metadata to reduce latency.
  • Differential Privacy: Clip each model update to a fixed L2 norm, then add noise calibrated to that bound (e.g., Gaussian noise) to obtain formal privacy guarantees.
  • Monitoring: Deploy Prometheus for training metrics (e.g., loss convergence) and ELK for logs.
  • Security: Enforce TLS for all communications and RBAC for server access control.
  • Client Selection: Implement dynamic client selection strategies to balance participation and resource usage.
  • Model Compression: Use quantization or pruning to reduce model update size for efficient transmission.
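
The Differential Privacy and Model Compression items can be combined in a client-side step that prepares each update before upload. This is a minimal sketch, assuming the update is the delta between locally trained and global weights: it clips the delta's L2 norm to clip_norm (bounding any one client's influence, which the noise calibration relies on), adds Gaussian noise with standard deviation noise_multiplier × clip_norm, and then applies symmetric 8-bit quantization. The parameter values are illustrative and not calibrated to a specific (ε, δ) budget.

```python
import numpy as np

def clip_by_l2_norm(delta, clip_norm):
    """Scale delta down so that its L2 norm is at most clip_norm."""
    norm = np.linalg.norm(delta)
    return delta * min(1.0, clip_norm / (norm + 1e-12))

def add_gaussian_noise(delta, clip_norm, noise_multiplier, rng):
    """Gaussian mechanism: the noise standard deviation scales with the clipping bound."""
    return delta + rng.normal(0.0, noise_multiplier * clip_norm, size=delta.shape)

def quantize_int8(x):
    """Symmetric linear quantization to int8; returns (codes, scale)."""
    max_abs = float(np.max(np.abs(x)))
    scale = max_abs / 127.0 if max_abs > 0 else 1.0
    codes = np.clip(np.round(x / scale), -127, 127).astype(np.int8)
    return codes, scale

def dequantize_int8(codes, scale):
    """Map int8 codes back to approximate float values (done on the server)."""
    return codes.astype(np.float32) * scale

rng = np.random.default_rng(7)
global_w = rng.normal(size=1000).astype(np.float32)
local_w = global_w + rng.normal(scale=0.5, size=1000).astype(np.float32)

delta = clip_by_l2_norm(local_w - global_w, clip_norm=1.0)
noisy = add_gaussian_noise(delta, clip_norm=1.0, noise_multiplier=0.1, rng=rng)
codes, scale = quantize_int8(noisy)

# int8 codes take a quarter of the float32 size; the server dequantizes before averaging
restored = dequantize_int8(codes, scale)
print(codes.nbytes, noisy.astype(np.float32).nbytes)  # 1000 4000
```

Sending codes plus the single float scale cuts upload size by roughly 4x, at the cost of a rounding error of at most scale / 2 per value.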
Regular privacy audits, model update validation, and client performance monitoring are critical for a secure and efficient system.
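
Client selection can start as simply as filtering for eligible devices and sampling a fixed number of them uniformly at random each round. The sketch below assumes hypothetical device-status fields (is_charging, on_unmetered_network, battery_level) and a hypothetical 50% battery threshold; production systems often layer fairness or staleness criteria on top.

```python
import random
from dataclasses import dataclass

@dataclass
class ClientStatus:
    client_id: str
    is_charging: bool
    on_unmetered_network: bool
    battery_level: float  # 0.0 to 1.0

def select_clients(clients, num_to_select, min_battery=0.5, seed=None):
    """Uniformly sample up to num_to_select clients among those eligible this round."""
    eligible = [
        c.client_id for c in clients
        if c.is_charging and c.on_unmetered_network and c.battery_level >= min_battery
    ]
    rng = random.Random(seed)
    return rng.sample(eligible, min(num_to_select, len(eligible)))

fleet = [
    ClientStatus('phone_1', True, True, 0.9),
    ClientStatus('phone_2', False, True, 0.8),   # not charging
    ClientStatus('iot_1', True, True, 0.6),
    ClientStatus('iot_2', True, False, 0.95),    # metered network
    ClientStatus('phone_3', True, True, 0.3),    # battery too low
    ClientStatus('phone_4', True, True, 0.7),
]
print(select_clients(fleet, num_to_select=2, seed=1))
```

Only phone_1, iot_1, and phone_4 are eligible here, so each round draws two of those three; restricting to charging, unmetered devices keeps training from draining batteries or mobile data plans.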

Example Configuration: Kafka for Model Update Streaming

Below are Kafka CLI commands that create a topic for model updates and test it with the console producer and consumer:

# Create a topic for model updates (8 partitions for parallel consumers, replication factor 3 for durability)
kafka-topics.sh --create \
  --bootstrap-server kafka:9092 \
  --partitions 8 \
  --replication-factor 3 \
  --topic model-updates

# Send keyed test messages with the console producer (enter one "key,value" pair per line)
kafka-console-producer.sh \
  --bootstrap-server kafka:9092 \
  --topic model-updates \
  --property parse.key=true \
  --property key.separator=,

# Read the messages back as the central server would, printing each message key
kafka-console-consumer.sh \
  --bootstrap-server kafka:9092 \
  --topic model-updates \
  --from-beginning \
  --property print.key=true \
  --property key.separator=,
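
Kafka carries message keys and values as bytes, so each client update has to be serialized before it is produced. A minimal sketch using NumPy's compressed .npz format, keyed by client ID so that, with Kafka's default partitioner, updates from the same client land on the same partition. The encode_update and decode_update helpers are hypothetical, and the kafka-python calls in the trailing comments assume that library is installed. Brokers reject messages larger than message.max.bytes (about 1 MB by default), so large models need compression, chunking, or a raised limit.

```python
import io

import numpy as np

def encode_update(client_id, layers):
    """Serialize a list of weight arrays to (key, value) bytes for a Kafka message."""
    buf = io.BytesIO()
    np.savez_compressed(buf, **{f'layer_{i}': layer for i, layer in enumerate(layers)})
    return client_id.encode('utf-8'), buf.getvalue()

def decode_update(key, value):
    """Inverse of encode_update: returns (client_id, list of arrays in original order)."""
    with np.load(io.BytesIO(value)) as npz:
        layers = [npz[f'layer_{i}'] for i in range(len(npz.files))]
    return key.decode('utf-8'), layers

layers = [np.ones((20, 64), dtype=np.float32), np.zeros(64, dtype=np.float32)]
key, value = encode_update('client_1', layers)
client_id, restored = decode_update(key, value)
print(client_id, [a.shape for a in restored], len(value))

# With kafka-python (assumed installed), the producer side would look like:
#   from kafka import KafkaProducer
#   producer = KafkaProducer(bootstrap_servers='kafka:9092')
#   producer.send('model-updates', key=key, value=value)
#   producer.flush()
```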
                

Example Configuration: Central Server with Secure Aggregation

Below is a Python-based Central Server implementation with a simplified FedAvg aggregation step, JWT-based RBAC, Redis caching, and MongoDB metadata storage:

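from datetime import datetime, timezone
from functools import wraps
import json
import os

import jwt
import numpy as np
import redis
import requests
from flask import Flask, request, jsonify
from pymongo import MongoClient

app = Flask(__name__)
# Read secrets and endpoints from the environment; the defaults are for local testing only
JWT_SECRET = os.getenv('JWT_SECRET', 'your-secret-key')
REDIS_URL = os.getenv('REDIS_URL', 'redis://redis-host:6379')
MONGO_URI = os.getenv('MONGO_URI', 'mongodb://mongo:27017')
MODEL_REGISTRY_URL = os.getenv('MODEL_REGISTRY_URL', 'http://mlflow:5000')
CACHE_TTL_SECONDS = 3600

# Initialize clients
redis_client = redis.Redis.from_url(REDIS_URL)
mongo_client = MongoClient(MONGO_URI)
db = mongo_client['federated_learning']

def check_rbac(required_role):
    """Reject requests without a valid HS256 JWT whose 'role' claim equals required_role."""
    def decorator(f):
        @wraps(f)  # preserve the view name; otherwise Flask sees two endpoints called 'wrapper'
        def wrapper(*args, **kwargs):
            auth_header = request.headers.get('Authorization', '')
            if not auth_header.startswith('Bearer '):
                return jsonify({'error': 'Unauthorized'}), 401
            token = auth_header.split(' ', 1)[1]
            try:
                decoded = jwt.decode(token, JWT_SECRET, algorithms=['HS256'])
            except jwt.InvalidTokenError:
                return jsonify({'error': 'Invalid token'}), 401
            if decoded.get('role') != required_role:
                return jsonify({'error': 'Insufficient permissions'}), 403
            return f(*args, **kwargs)
        return wrapper
    return decorator

def secure_aggregation(updates, noise_std=0.01):
    """Simplified FedAvg: weighted per-layer average of client weights plus Gaussian noise.

    Each update is {'weights': [layer, ...], 'num_examples': int}. Layers are averaged
    one at a time because Keras weight lists are ragged (each layer has its own shape).
    This is not a cryptographic secure aggregation protocol, and the noise is
    illustrative: formal differential privacy also requires clipping each client's
    update and calibrating noise_std to a privacy budget.
    """
    counts = [u.get('num_examples', 1) for u in updates]
    total = float(sum(counts))
    global_layers = []
    for i in range(len(updates[0]['weights'])):
        layer = sum(np.asarray(u['weights'][i], dtype=np.float64) * (n / total)
                    for u, n in zip(updates, counts))
        global_layers.append(layer + np.random.normal(0, noise_std, layer.shape))
    return global_layers

@app.route('/aggregate', methods=['POST'])
@check_rbac('federated')
def aggregate():
    data = request.get_json(silent=True) or {}
    client_updates = data.get('updates')
    session_id = data.get('session_id')
    if not client_updates or not session_id:
        return jsonify({'error': 'updates and session_id are required'}), 400

    # Return the cached result if this session was aggregated recently
    cache_key = f'aggregation:{session_id}'
    cached = redis_client.get(cache_key)
    if cached:
        return jsonify({'global_model': json.loads(cached)})

    # Perform aggregation (FedAvg plus noise) and convert layers to JSON-friendly lists
    global_weights = [layer.tolist() for layer in secure_aggregation(client_updates)]

    # Save to Model Registry (mocked: '/models' is a placeholder, not an MLflow REST endpoint)
    model_id = f'model_{session_id}'
    requests.post(f'{MODEL_REGISTRY_URL}/models', json={
        'model_id': model_id,
        'weights': global_weights
    }, timeout=10)

    # Cache as JSON (so it can be decoded back into lists) and store metadata
    redis_client.setex(cache_key, CACHE_TTL_SECONDS, json.dumps(global_weights))
    db['training_metadata'].update_one(
        {'session_id': session_id},
        {'$set': {
            'global_weights': global_weights,
            'client_count': len(client_updates),
            'updated_at': datetime.now(timezone.utc)
        }},
        upsert=True
    )

    return jsonify({'global_model': global_weights})

@app.route('/model', methods=['GET'])
@check_rbac('federated')
def distribute_model():
    session_id = request.args.get('session_id')
    if not session_id:
        return jsonify({'error': 'session_id is required'}), 400
    cached = redis_client.get(f'aggregation:{session_id}')
    if cached:
        return jsonify({'global_model': json.loads(cached)})
    model = db['training_metadata'].find_one({'session_id': session_id})
    if model:
        return jsonify({'global_model': model['global_weights']})
    return jsonify({'error': 'Model not found'}), 404

if __name__ == '__main__':
    # Flask's development server with TLS; use a production WSGI server in deployment
    app.run(host='0.0.0.0', port=5000, ssl_context=('server-cert.pem', 'server-key.pem'))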
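
To exercise the RBAC check, each client needs an HS256 token whose role claim is "federated". A minimal sketch with PyJWT, assuming the same shared JWT_SECRET as the server and a hypothetical one-hour expiry; in practice an identity service, not the client itself, would issue these tokens.

```python
from datetime import datetime, timedelta, timezone

import jwt  # PyJWT

JWT_SECRET = 'your-secret-key'  # must match the server's JWT_SECRET

def issue_client_token(client_id, role='federated', ttl=timedelta(hours=1)):
    """Mint a signed token carrying the role that the server's RBAC check expects."""
    now = datetime.now(timezone.utc)
    claims = {'sub': client_id, 'role': role, 'iat': now, 'exp': now + ttl}
    return jwt.encode(claims, JWT_SECRET, algorithm='HS256')

token = issue_client_token('client_1')
# Verify the signature and expiry the same way the server's decorator does
decoded = jwt.decode(token, JWT_SECRET, algorithms=['HS256'])
print(decoded['sub'], decoded['role'])  # client_1 federated
```

The client then sends the token as "Authorization: Bearer <token>"; once exp passes, jwt.decode raises an ExpiredSignatureError (a subclass of InvalidTokenError), so the server answers 401.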
                

Example Configuration: Client-Side Local Training

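Below is a Python-based client-side implementation of local training with TensorFlow Keras. The client starts from the current global model, trains on its local data, and sends only the resulting weights to the Central Server: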

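import os

import numpy as np
import requests
import tensorflow as tf

SERVER_URL = os.getenv('CENTRAL_SERVER_URL', 'https://central-server:5000')
JWT_TOKEN = os.getenv('JWT_TOKEN', 'your-jwt-token')
# CA bundle that signed the server's TLS certificate (True = use the system CA store)
CA_CERT = os.getenv('CA_CERT', True)
INPUT_DIM = 20  # hypothetical feature count; set to match your data
NUM_CLASSES = 10
AUTH_HEADERS = {'Authorization': f'Bearer {JWT_TOKEN}'}

def create_model():
    model = tf.keras.Sequential([
        tf.keras.Input(shape=(INPUT_DIM,)),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(NUM_CLASSES, activation='softmax')
    ])
    return model

def load_local_data(num_examples=256, batch_size=32):
    # Placeholder: random features and labels stand in for real on-device data
    x = np.random.rand(num_examples, INPUT_DIM).astype('float32')
    y = np.random.randint(0, NUM_CLASSES, size=num_examples)
    return tf.data.Dataset.from_tensor_slices((x, y)).batch(batch_size), num_examples

def fetch_global_weights(session_id):
    # Start from the current global model if the server already has one for this session
    response = requests.get(f'{SERVER_URL}/model', params={'session_id': session_id},
                            headers=AUTH_HEADERS, verify=CA_CERT, timeout=30)
    if response.status_code != 200:
        return None
    return [np.asarray(layer, dtype='float32') for layer in response.json()['global_model']]

def local_training(dataset, global_weights=None):
    model = create_model()
    if global_weights is not None:
        model.set_weights(global_weights)
    model.compile(optimizer='adam', loss='sparse_categorical_crossentropy')
    model.fit(dataset, epochs=5, verbose=0)
    return model.get_weights()

def main():
    dataset, num_examples = load_local_data()  # Replace with actual data loading
    client_id = os.getenv('CLIENT_ID', 'client_1')
    session_id = 'training_session_1'

    # Perform local training, starting from the latest global model if available
    local_weights = local_training(dataset, fetch_global_weights(session_id))

    # Send updates to the central server. NumPy arrays are not JSON-serializable, so
    # each layer is converted to nested lists. For brevity this posts straight to
    # /aggregate; a full pipeline would collect updates from many clients (e.g., via
    # Kafka) before aggregating a round.
    response = requests.post(
        f'{SERVER_URL}/aggregate',
        json={'updates': [{'client_id': client_id,
                           'num_examples': num_examples,
                           'weights': [w.tolist() for w in local_weights]}],
              'session_id': session_id},
        headers=AUTH_HEADERS,
        verify=CA_CERT,
        timeout=30
    )

    if response.status_code == 200:
        print('Model updates sent successfully')
    else:
        print('Failed to send updates:', response.status_code, response.text)

if __name__ == '__main__':
    main()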