class ServerService(server_grpc.ServerServiceServicer):
    """Remote gRPC server service.

    Handles the ``Run`` and ``Upload`` RPCs. Uploads of the current round are
    buffered per client and handed to the wrapped server once the number of
    received uploads equals the number of clients expected for the round.

    Args:
        server (:obj:`BaseServer`): Federated learning server instance.
    """

    def __init__(self, server):
        self._base = server

        # Every Upload RPC is processed on its own thread (see ``Upload``), so
        # all round bookkeeping below is guarded by this lock. It is re-entrant
        # so that the helper methods may call each other while holding it.
        self._lock = threading.RLock()

        # Number of uploads that completes a round; refreshed on every upload.
        self._clients_per_round = 0

        # Training round buffers, keyed by client id where applicable.
        # NOTE(review): this counter is shared by the regular training uploads
        # and the customized-content uploads.
        self._train_client_count = 0
        self._uploaded_models = {}
        self._uploaded_weights = {}
        # NOTE(review): this list is shared by training and testing uploads.
        self._uploaded_metrics = []

        # Customized-content round buffers.
        self._uploaded_content1 = {}
        self._uploaded_content2 = {}
        self._uploaded_content1_name = "content1"
        self._uploaded_content2_name = "content2"

        # Testing round buffers.
        self._test_client_count = 0
        self._test_metrics = {}
        self._test_sizes = []

    def Run(self, request, context):
        """Trigger federated learning process.

        Args:
            request: Run request carrying the serialized ``model`` and the
                participating ``clients``.
            context: gRPC call context (unused).

        Returns:
            A ``RunResponse`` with status ``SC_OK`` when training was started,
            or ``SC_ALREADY_EXISTS`` when a training is already in progress.
        """
        if self._base.is_training:
            return server_pb.RunResponse(
                status=common_pb.Status(
                    code=common_pb.SC_ALREADY_EXISTS,
                    message="Training in progress, please stop current training or wait for completion",
                ),
            )

        model = codec.unmarshal(request.model)
        self._base.start_remote_training(model, request.clients)
        return server_pb.RunResponse(
            status=common_pb.Status(code=common_pb.SC_OK),
        )

    def Upload(self, request, context):
        """Handle upload from clients.

        The upload is processed on a background thread so the RPC can return
        immediately; the response therefore only acknowledges receipt.

        Args:
            request: Upload request from a client.
            context: gRPC call context, forwarded to the handler.

        Returns:
            An ``UploadResponse`` with status ``SC_OK``.
        """
        worker = threading.Thread(target=self._handle_upload, args=(request, context))
        worker.start()
        return server_pb.UploadResponse(
            status=common_pb.Status(code=common_pb.SC_OK),
        )

    def _handle_upload(self, request, context):
        """Decode one client upload and dispatch it by content type.

        Args:
            request: Upload request; ``request.content`` provides ``data``,
                ``data_size``, ``metric`` and ``type``.
            context: gRPC call context (unused).
        """
        # Decoding touches no shared state, so it is done outside the lock.
        data = codec.unmarshal(request.content.data)
        data_size = request.content.data_size
        client_metric = metric.ClientMetric.from_proto(request.content.metric)
        content_type = request.content.type

        # Concurrent uploads must not interleave while updating the buffers
        # and counters, otherwise the round-completion check can be missed.
        with self._lock:
            clients_per_round = self._base.conf.server.clients_per_round
            num_of_clients = self._base.num_of_clients()
            if num_of_clients < clients_per_round:
                # TODO: use a more appropriate way to handle this situation
                logger.warning(
                    "Available number of clients {} is smaller than clients per round {}".format(
                        num_of_clients, clients_per_round))
                self._clients_per_round = num_of_clients
            else:
                self._clients_per_round = clients_per_round

            if content_type == common_pb.DATA_TYPE_PARAMS:
                self._handle_upload_train(request.client_id, data, data_size, client_metric)
            elif content_type == common_pb.DATA_TYPE_PERFORMANCE:
                self._handle_upload_test(data, data_size, client_metric)
            elif content_type == common_pb.DATA_TYPE_FEATURE:  # 2: feature_label
                self._handle_upload_train_customize_content(
                    request.client_id, data, data_size, client_metric)
            else:
                # Previously such uploads were dropped without any trace.
                logger.warning(
                    "Ignoring upload with unsupported content type {}".format(content_type))

    def _handle_upload_train(self, client_id, data, data_size, client_metric):
        """Buffer one client's training upload and check for round completion.

        Args:
            client_id: Identifier of the uploading client.
            data: Compressed model as received from the client.
            data_size: Stored as the client's weight for this round.
            client_metric: Metric object reported by the client.
        """
        with self._lock:
            model = self._base.decompression(data)
            self._uploaded_models[client_id] = model
            self._uploaded_weights[client_id] = data_size
            self._uploaded_metrics.append(client_metric)
            self._train_client_count += 1
            self._trigger_aggregate_train()

    def _handle_upload_test(self, data, data_size, client_metric):
        """Buffer one client's test upload and check for round completion.

        Args:
            data: Upload whose ``performance`` mapping is grouped per key.
            data_size: Size reported with the upload.
            client_metric: Metric object reported by the client.
        """
        with self._lock:
            for key, value in data.performance.items():
                self._test_metrics.setdefault(key, []).append(value)
            self._test_sizes.append(data_size)
            self._uploaded_metrics.append(client_metric)
            self._test_client_count += 1
            self._trigger_aggregate_test()

    def _handle_upload_train_customize_content(self, client_id, data, data_size, client_metric):
        """Buffer one client's customized content and check for round completion.

        The decompressed upload is indexed as ``["content"][0|1]`` and
        ``["name"][0|1]``.

        Args:
            client_id: Identifier of the uploading client.
            data: Compressed customized content as received from the client.
            data_size: Unused.
            client_metric: Unused.
        """
        with self._lock:
            customize_dict = self._base.decompression(data)
            contents = customize_dict["content"]
            names = customize_dict["name"]
            self._uploaded_content1[client_id] = contents[0]
            self._uploaded_content2[client_id] = contents[1]
            self._uploaded_content1_name = names[0]
            self._uploaded_content2_name = names[1]
            self._train_client_count += 1
            self._trigger_pack_customize_content_train()

    def _trigger_aggregate_train(self):
        """Hand the training uploads to the server once all clients reported."""
        with self._lock:
            logger.info("train_client_count: {}/{}".format(
                self._train_client_count, self._clients_per_round))
            if self._train_client_count != self._clients_per_round:
                return
            self._base.set_client_uploads_train(
                self._uploaded_models, self._uploaded_weights, self._uploaded_metrics)
            self._train_client_count = 0
            # Rebinds fresh containers; the ones just handed over stay intact.
            self._reset_train_cache()
            with self._base.condition():
                self._base.notify_all()

    def _trigger_aggregate_test(self):
        """Hand the test uploads to the server once all clients reported."""
        # TODO: determine the testing clients not only by the selected number of clients
        with self._lock:
            if self._test_client_count != self._clients_per_round:
                return
            self._base.set_client_uploads_test(
                self._test_metrics, self._test_sizes, self._uploaded_metrics)
            self._test_client_count = 0
            self._reset_test_cache()
            with self._base.condition():
                self._base.notify_all()

    def _trigger_pack_customize_content_train(self):
        """Hand the customized content to the server once all clients reported."""
        with self._lock:
            if self._train_client_count != self._clients_per_round:
                return
            self._base.set_client_uploads_customize_content(
                self._uploaded_content1,
                self._uploaded_content2,
                self._uploaded_content1_name,
                self._uploaded_content2_name,
            )
            self._train_client_count = 0
            self._reset_customize_content_cache()
            with self._base.condition():
                self._base.notify_all()

    def _reset_train_cache(self):
        """Start a new, empty set of training buffers."""
        self._uploaded_models = {}
        self._uploaded_weights = {}
        self._uploaded_metrics = []

    def _reset_test_cache(self):
        """Start a new, empty set of testing buffers."""
        self._test_metrics = {}
        self._test_sizes = []
        self._uploaded_metrics = []

    def _reset_customize_content_cache(self):
        """Start a new, empty set of customized-content buffers and names."""
        self._uploaded_content1 = {}
        self._uploaded_content2 = {}
        self._uploaded_content1_name = "content1"
        self._uploaded_content2_name = "content2"