# Source code for coala.server.service

import logging
import threading

from coala.pb import server_service_pb2_grpc as server_grpc, server_service_pb2 as server_pb, common_pb2 as common_pb
from coala.protocol import codec
from coala.tracking import metric

logger = logging.getLogger(__name__)


class ServerService(server_grpc.ServerServiceServicer):
    """Remote gRPC server service.

    Receives Run/Stop/Upload RPCs and relays them to the federated
    learning server instance.

    Args:
        server (:obj:`BaseServer`): Federated learning server instance.
    """

    def __init__(self, server):
        self._base = server

        # Effective number of clients expected this round; recomputed on
        # each upload from the server configuration.
        self._clients_per_round = 0

        # Accumulators for training uploads.
        self._train_client_count = 0
        self._uploaded_models = {}
        self._uploaded_weights = {}
        self._uploaded_metrics = []

        # Accumulators for customized-content uploads.
        self._uploaded_content1 = {}
        self._uploaded_content2 = {}
        self._uploaded_content1_name = "content1"
        self._uploaded_content2_name = "content2"

        # Accumulators for testing uploads.
        self._test_client_count = 0
        self._test_metrics = {}
        self._test_sizes = []
[docs] def Run(self, request, context): """Trigger federated learning process.""" response = server_pb.RunResponse( status=common_pb.Status(code=common_pb.SC_OK), ) if self._base.is_training: response = server_pb.RunResponse( status=common_pb.Status( code=common_pb.SC_ALREADY_EXISTS, message="Training in progress, please stop current training or wait for completion", ), ) else: model = codec.unmarshal(request.model) self._base.start_remote_training(model, request.clients) return response
[docs] def Stop(self, request, context): """Stop federated learning process.""" response = server_pb.StopResponse( status=common_pb.Status(code=common_pb.SC_OK), ) if self._base.is_training(): self._base.stop() else: response = server_pb.RunResponse( status=common_pb.Status( code=common_pb.SC_NOT_FOUND, message="No existing training", ), ) return response
[docs] def Upload(self, request, context): """Handle upload from clients.""" # TODO: put train and test logic in a separate thread and add thread lock to ensure atomicity. t = threading.Thread(target=self._handle_upload, args=[request, context]) t.start() response = server_pb.UploadResponse( status=common_pb.Status(code=common_pb.SC_OK), ) return response
    def _handle_upload(self, request, context):
        """Decode an upload and dispatch it by content type.

        Runs on a worker thread spawned by ``Upload``. Also refreshes
        ``self._clients_per_round`` from the server configuration, capped
        at the number of currently available clients.
        """
        # if not self._base.upload_event.is_set():
        data = codec.unmarshal(request.content.data)
        data_size = request.content.data_size
        client_metric = metric.ClientMetric.from_proto(request.content.metric)
        clients_per_round = self._base.conf.server.clients_per_round
        num_of_clients = self._base.num_of_clients()
        if num_of_clients < clients_per_round:
            # TODO: use a more appropriate way to handle this situation
            logger.warning(
                "Available number of clients {} is smaller than clients per round {}".format(num_of_clients, clients_per_round))
            self._clients_per_round = num_of_clients
        else:
            self._clients_per_round = clients_per_round
        if request.content.type == common_pb.DATA_TYPE_PARAMS:
            self._handle_upload_train(request.client_id, data, data_size, client_metric)
        elif request.content.type == common_pb.DATA_TYPE_PERFORMANCE:
            self._handle_upload_test(data, data_size, client_metric)
        elif request.content.type == common_pb.DATA_TYPE_FEATURE:
            # 2: feature_label
            self._handle_upload_train_customize_content(request.client_id, data, data_size, client_metric)

    def _handle_upload_train(self, client_id, data, data_size, client_metric):
        """Cache one client's trained model, weight and metric, then check
        whether the round is complete."""
        model = self._base.decompression(data)
        self._uploaded_models[client_id] = model
        # data_size doubles as the aggregation weight for this client.
        self._uploaded_weights[client_id] = data_size
        self._uploaded_metrics.append(client_metric)
        self._train_client_count += 1
        self._trigger_aggregate_train()

    def _handle_upload_test(self, data, data_size, client_metric):
        """Accumulate one client's test performance values per metric key."""
        for key, value in data.performance.items():
            if key in self._test_metrics:
                self._test_metrics[key].append(value)
            else:
                self._test_metrics[key] = [value]
        self._test_sizes.append(data_size)
        self._uploaded_metrics.append(client_metric)
        self._test_client_count += 1
        self._trigger_aggregate_test()

    def _handle_upload_train_customize_content(self, client_id, data, data_size, client_metric):
        """Cache one client's customized upload content and names.

        Assumes the decompressed payload is a dict with two-element
        "content" and "name" sequences — TODO confirm against the client
        encoding.
        """
        customize_dict = self._base.decompression(data)
        self._uploaded_content1[client_id] = customize_dict["content"][0]
        self._uploaded_content2[client_id] = customize_dict["content"][1]
        self._uploaded_content1_name = customize_dict["name"][0]
        self._uploaded_content2_name = customize_dict["name"][1]
        self._train_client_count += 1
        self._trigger_pack_customize_content_train()

    def _trigger_aggregate_train(self):
        """Hand cached training uploads to the server and wake the waiting
        training loop once every expected client has uploaded."""
        logger.info("train_client_count: {}/{}".format(self._train_client_count, self._clients_per_round))
        if self._train_client_count == self._clients_per_round:
            self._base.set_client_uploads_train(self._uploaded_models, self._uploaded_weights, self._uploaded_metrics)
            # Reset counters/caches before notifying so the next round
            # starts from a clean slate.
            self._train_client_count = 0
            self._reset_train_cache()
            with self._base.condition():
                self._base.notify_all()

    def _trigger_aggregate_test(self):
        """Hand cached test results to the server once all expected
        clients have reported, then wake the waiting loop."""
        # TODO: determine the testing clients not only by the selected number of clients
        if self._test_client_count == self._clients_per_round:
            self._base.set_client_uploads_test(self._test_metrics, self._test_sizes, self._uploaded_metrics)
            self._test_client_count = 0
            self._reset_test_cache()
            with self._base.condition():
                self._base.notify_all()

    def _trigger_pack_customize_content_train(self):
        """Hand cached customized content to the server once all expected
        clients have uploaded, then wake the waiting loop."""
        if self._train_client_count == self._clients_per_round:
            self._base.set_client_uploads_customize_content(self._uploaded_content1, self._uploaded_content2,
                                                            self._uploaded_content1_name, self._uploaded_content2_name)
            self._train_client_count = 0
            self._reset_customize_content_cache()
            with self._base.condition():
                self._base.notify_all()

    def _reset_train_cache(self):
        """Clear per-round training upload caches."""
        self._uploaded_models = {}
        self._uploaded_weights = {}
        self._uploaded_metrics = []

    def _reset_test_cache(self):
        """Clear per-round testing upload caches."""
        self._test_metrics = {}
        self._test_sizes = []
        self._uploaded_metrics = []

    def _reset_customize_content_cache(self):
        """Clear per-round customized-content caches and restore default names."""
        self._uploaded_content1 = {}
        self._uploaded_content2 = {}
        self._uploaded_content1_name = "content1"
        self._uploaded_content2_name = "content2"