# Source code for coala.server.service
import logging
import threading
from coala.pb import server_service_pb2_grpc as server_grpc, server_service_pb2 as server_pb, common_pb2 as common_pb
from coala.protocol import codec
from coala.tracking import metric
logger = logging.getLogger(__name__)
class ServerService(server_grpc.ServerServiceServicer):
    """Remote gRPC server service.

    Exposes the ``Run``, ``Stop`` and ``Upload`` RPCs. Client uploads are
    buffered until as many of them as ``_clients_per_round`` have arrived;
    the buffered data is then handed to the wrapped server, the buffers are
    cleared, and the threads waiting on the server's condition are notified.

    Args:
        server (:obj:`BaseServer`): Federated learning server instance.
    """

    def __init__(self, server):
        self._base = server

        # Upload() processes every request on its own thread, so the
        # bookkeeping below is only mutated while this lock is held.
        self._lock = threading.Lock()

        # Number of uploads that completes a round; refreshed on every upload.
        self._clients_per_round = 0

        # Buffers for training uploads (models and weights keyed by client id).
        self._train_client_count = 0
        self._uploaded_models = {}
        self._uploaded_weights = {}
        # NOTE(review): this list is shared by the train and the test path and
        # is cleared by both reset helpers.
        self._uploaded_metrics = []

        # Buffers for customized-content uploads (keyed by client id).
        self._uploaded_content1 = {}
        self._uploaded_content2 = {}
        self._uploaded_content1_name = "content1"
        self._uploaded_content2_name = "content2"

        # Buffers for test uploads (metric name -> list of reported values).
        self._test_client_count = 0
        self._test_metrics = {}
        self._test_sizes = []

    def Run(self, request, context):
        """Trigger federated learning process.

        Args:
            request: Run request. ``request.model`` is unmarshalled with
                ``codec`` and passed, together with ``request.clients``, to
                ``start_remote_training`` of the wrapped server.
            context: gRPC context (unused).

        Returns:
            server_pb.RunResponse: Status ``SC_OK`` when training was started,
            ``SC_ALREADY_EXISTS`` when a training is already in progress.
        """
        # Call is_training() exactly as Stop() does: testing the bound method
        # object itself is always truthy and would reject every request.
        # NOTE(review): confirm BaseServer.is_training is not a property.
        if self._base.is_training():
            return server_pb.RunResponse(
                status=common_pb.Status(
                    code=common_pb.SC_ALREADY_EXISTS,
                    message="Training in progress, please stop current training or wait for completion",
                ),
            )

        model = codec.unmarshal(request.model)
        self._base.start_remote_training(model, request.clients)
        return server_pb.RunResponse(
            status=common_pb.Status(code=common_pb.SC_OK),
        )

    def Stop(self, request, context):
        """Stop federated learning process.

        Args:
            request: Stop request (unused).
            context: gRPC context (unused).

        Returns:
            server_pb.StopResponse: Status ``SC_OK`` when a running training
            was stopped, ``SC_NOT_FOUND`` when no training is in progress.
        """
        if self._base.is_training():
            self._base.stop()
            return server_pb.StopResponse(
                status=common_pb.Status(code=common_pb.SC_OK),
            )

        # This RPC must answer with a StopResponse, not a RunResponse.
        return server_pb.StopResponse(
            status=common_pb.Status(
                code=common_pb.SC_NOT_FOUND,
                message="No existing training",
            ),
        )

    def Upload(self, request, context):
        """Handle upload from clients.

        The upload is processed on a newly started thread, so ``SC_OK`` is
        returned without waiting for the processing to finish.

        Args:
            request: Upload request, forwarded to ``_handle_upload``.
            context: gRPC context, forwarded to ``_handle_upload``.

        Returns:
            server_pb.UploadResponse: Always carries status ``SC_OK``.
        """
        worker = threading.Thread(target=self._handle_upload, args=(request, context))
        worker.start()
        return server_pb.UploadResponse(
            status=common_pb.Status(code=common_pb.SC_OK),
        )

    def _handle_upload(self, request, context):
        """Decode one upload and dispatch it on ``request.content.type``.

        Args:
            request: Upload request; ``request.content`` provides ``data``,
                ``data_size``, ``metric`` and ``type``, and
                ``request.client_id`` identifies the sender.
            context: gRPC context (unused).
        """
        content = request.content
        data = codec.unmarshal(content.data)
        data_size = content.data_size
        client_metric = metric.ClientMetric.from_proto(content.metric)

        clients_per_round = self._base.conf.server.clients_per_round
        num_of_clients = self._base.num_of_clients()
        if num_of_clients < clients_per_round:
            # TODO: use a more appropriate way to handle this situation
            logger.warning(
                "Available number of clients %s is smaller than clients per round %s",
                num_of_clients,
                clients_per_round,
            )
            clients_per_round = num_of_clients
        with self._lock:
            self._clients_per_round = clients_per_round

        content_type = content.type
        if content_type == common_pb.DATA_TYPE_PARAMS:
            self._handle_upload_train(request.client_id, data, data_size, client_metric)
        elif content_type == common_pb.DATA_TYPE_PERFORMANCE:
            self._handle_upload_test(data, data_size, client_metric)
        elif content_type == common_pb.DATA_TYPE_FEATURE:  # 2: feature_label
            self._handle_upload_train_customize_content(request.client_id, data, data_size, client_metric)
        else:
            # Make dropped uploads visible instead of ignoring them silently.
            logger.warning(
                "Ignoring upload from client %s with unsupported content type %s",
                request.client_id,
                content_type,
            )

    def _handle_upload_train(self, client_id, data, data_size, client_metric):
        """Buffer one training upload and aggregate once the round is complete.

        Args:
            client_id: Key under which the model and its weight are stored.
            data: Uploaded payload, passed to the server's ``decompression``.
            data_size: Stored as the weight of this client's model.
            client_metric: Metric appended to the buffered metrics.
        """
        # Decompress before taking the lock so other uploads are not blocked.
        model = self._base.decompression(data)
        with self._lock:
            self._uploaded_models[client_id] = model
            self._uploaded_weights[client_id] = data_size
            self._uploaded_metrics.append(client_metric)
            self._train_client_count += 1
            self._trigger_aggregate_train()

    def _handle_upload_test(self, data, data_size, client_metric):
        """Buffer one test upload and aggregate once the round is complete.

        Args:
            data: Uploaded payload; each ``(key, value)`` pair of
                ``data.performance`` is appended to the list kept for ``key``.
            data_size: Appended to the buffered test sizes.
            client_metric: Metric appended to the buffered metrics.
        """
        with self._lock:
            for key, value in data.performance.items():
                self._test_metrics.setdefault(key, []).append(value)
            self._test_sizes.append(data_size)
            self._uploaded_metrics.append(client_metric)
            self._test_client_count += 1
            self._trigger_aggregate_test()

    def _handle_upload_train_customize_content(self, client_id, data, data_size, client_metric):
        """Buffer one customized-content upload and pack once the round is complete.

        The decompressed payload is indexed as ``["content"][0]``,
        ``["content"][1]``, ``["name"][0]`` and ``["name"][1]``.

        Args:
            client_id: Key under which both content items are stored.
            data: Uploaded payload, passed to the server's ``decompression``.
            data_size: Unused by this handler.
            client_metric: Unused by this handler.
        """
        # Decompress before taking the lock so other uploads are not blocked.
        customize_dict = self._base.decompression(data)
        with self._lock:
            self._uploaded_content1[client_id] = customize_dict["content"][0]
            self._uploaded_content2[client_id] = customize_dict["content"][1]
            self._uploaded_content1_name = customize_dict["name"][0]
            self._uploaded_content2_name = customize_dict["name"][1]
            self._train_client_count += 1
            self._trigger_pack_customize_content_train()

    def _trigger_aggregate_train(self):
        """Hand buffered training uploads to the server when the round is full.

        The upload handlers call this while holding ``self._lock``.
        """
        logger.info("train_client_count: %s/%s", self._train_client_count, self._clients_per_round)
        if self._train_client_count != self._clients_per_round:
            return
        self._base.set_client_uploads_train(
            self._uploaded_models,
            self._uploaded_weights,
            self._uploaded_metrics,
        )
        self._train_client_count = 0
        self._reset_train_cache()
        with self._base.condition():
            self._base.notify_all()

    def _trigger_aggregate_test(self):
        """Hand buffered test uploads to the server when the round is full.

        The upload handlers call this while holding ``self._lock``.
        """
        # TODO: determine the testing clients not only by the selected number of clients
        if self._test_client_count != self._clients_per_round:
            return
        self._base.set_client_uploads_test(
            self._test_metrics,
            self._test_sizes,
            self._uploaded_metrics,
        )
        self._test_client_count = 0
        self._reset_test_cache()
        with self._base.condition():
            self._base.notify_all()

    def _trigger_pack_customize_content_train(self):
        """Hand buffered customized content to the server when the round is full.

        The upload handlers call this while holding ``self._lock``.
        """
        if self._train_client_count != self._clients_per_round:
            return
        self._base.set_client_uploads_customize_content(
            self._uploaded_content1,
            self._uploaded_content2,
            self._uploaded_content1_name,
            self._uploaded_content2_name,
        )
        self._train_client_count = 0
        self._reset_customize_content_cache()
        with self._base.condition():
            self._base.notify_all()

    def _reset_train_cache(self):
        """Start fresh training buffers.

        New containers are bound instead of clearing the old ones in place,
        because the old ones were just passed to the server.
        """
        self._uploaded_models = {}
        self._uploaded_weights = {}
        self._uploaded_metrics = []

    def _reset_test_cache(self):
        """Start fresh test buffers (new containers, see ``_reset_train_cache``)."""
        self._test_metrics = {}
        self._test_sizes = []
        self._uploaded_metrics = []

    def _reset_customize_content_cache(self):
        """Start fresh customized-content buffers and restore the default names."""
        self._uploaded_content1 = {}
        self._uploaded_content2 = {}
        self._uploaded_content1_name = "content1"
        self._uploaded_content2_name = "content2"