import argparse
import concurrent.futures
import copy
import logging
import os
import threading
import time

import numpy as np
import torch
import torch.distributed as dist
from omegaconf import OmegaConf

from coala.communication import grpc_wrapper
from coala.datasets import TEST_IN_SERVER
from coala.distributed import grouping, reduce_models, reduce_models_only_params, \
    reduce_value, reduce_values, reduce_weighted_values, gather_value
from coala.distributed.distributed import CPU, GREEDY_GROUPING
from coala.pb import client_service_pb2 as client_pb
from coala.pb import common_pb2 as common_pb
from coala.protocol import codec
from coala.registry.etcd_client import EtcdClient
from coala.server import strategies
from coala.server.service import ServerService
from coala.tracking import metric
from coala.tracking.client import init_tracking
from coala.utils.float import rounding

logger = logging.getLogger(__name__)

# Train and test params.
MODEL = "model"
DATA_SIZE = "data_size"
METRIC = "metric"
CLIENT_METRICS = "client_metrics"

FEDERATED_AVERAGE = "FedAvg"
EQUAL_AVERAGE = "equal"

AGGREGATION_CONTENT_ALL = "all"
AGGREGATION_CONTENT_PARAMS = "parameters"


def create_argument_parser():
    """Create an argument parser with arguments/configurations for starting the server service.

    Returns:
        argparse.ArgumentParser: The parser with server service arguments.
    """
    parser = argparse.ArgumentParser(description='Federated Server')
    parser.add_argument('--local-port',
                        type=int,
                        default=22999,
                        help='Listen port of the server service')
    parser.add_argument('--tracker-addr',
                        type=str,
                        default="localhost:12666",
                        help='Address of tracking service in [IP]:[PORT] format')
    # Note: argparse's type=bool treats any non-empty string as True.
    parser.add_argument('--is-remote',
                        type=bool,
                        default=False,
                        help='Whether to start as a remote server.')
    return parser
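
# Usage sketch (illustrative): a launcher script might build the parser and read
# the service options before starting the server. The flag values below are
# hypothetical examples, shown as doctest-style comments so module import is unaffected.
#
# >>> parser = create_argument_parser()
# >>> args = parser.parse_args(["--local-port", "23000", "--tracker-addr", "localhost:12666"])
# >>> args.local_port
# 23000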


class BaseServer(object):
    """Default implementation of a federated learning server.

    Args:
        conf (omegaconf.dictconfig.DictConfig): Configurations of COALA.
        test_data (:obj:`FederatedDataset`): Test dataset for centralized testing in the server, optional.
        val_data (:obj:`FederatedDataset`): Validation dataset for centralized validation in the server, optional.
        is_remote (bool): A flag to indicate whether to start remote training.
        local_port (int): The port of the remote server service.

    Override the class and its functions to implement a customized server.

    Example:
        >>> from coala.server import BaseServer
        >>> class CustomizedServer(BaseServer):
        >>>     def __init__(self, conf, test_data=None, val_data=None, is_remote=False, local_port=22999):
        >>>         super(CustomizedServer, self).__init__(conf, test_data, val_data, is_remote, local_port)
        >>>         pass  # more initialization of attributes.
        >>>
        >>>     def aggregation(self):
        >>>         # Implement a customized aggregation method, which overwrites the default aggregation method.
        >>>         pass
    """

    def __init__(self, conf, test_data=None, val_data=None, is_remote=False, local_port=22999):
        self.conf = conf
        self.test_data = test_data
        self.val_data = val_data
        self.is_remote = is_remote
        self.local_port = local_port
        self.is_training = False
        self.should_stop = False

        self.current_round = -1
        self.client_uploads = {}
        self.model = None
        self.clients = None
        self.selected_clients = []
        self.grouped_clients = []

        self.tracker = None
        self.cumulative_times = []  # cumulative training time after each test
        self.performance_metrics = {}
        self.visualize = conf.server.visualize
        self.client_stubs = {}
        self._etcd = None

        self._server_metric = None
        self._round_time = None
        self._begin_train_time = None  # training begin time for a round
        self._start_time = None  # training start time for a task

        if self.conf.is_distributed:
            self.default_time = self.conf.resource_heterogeneous.initial_default_time

        self._condition = threading.Condition()

        self.init_tracker()

    def start(self, model, clients):
        """Start the federated learning process, including training and testing.

        Args:
            model (nn.Module): The model to train.
            clients (list[:obj:`BaseClient`]|list[str]): Available clients.
                Clients are actually client gRPC addresses when in remote training.
        """
        # Setup
        self._start_time = time.time()
        self._reset()
        self.set_model(model)
        self.set_clients(clients)

        if self._should_track():
            self.tracker.create_task(self.conf.task_id, OmegaConf.to_container(self.conf))

        # Get initial testing accuracies
        if self.conf.server.test_all:
            if self._should_track():
                self.tracker.set_round(self.current_round)
            self.test()
            self.save_tracker()

        while not self.should_terminate():
            self._round_time = time.time()
            self.current_round += 1
            self.print_("\n-------- round {} --------".format(self.current_round))

            # Train
            self.pre_train()
            self.train()
            self.post_train()

            # Test
            if self._do_every(self.conf.server.test_every, self.current_round, self.conf.server.rounds):
                self.pre_test()
                self.test()
                self.post_test()

            # Save model
            self.save_model()
            self.track(metric.ROUND_TIME, time.time() - self._round_time)
            self.save_tracker()

        for key, values in self.performance_metrics.items():
            self.print_("{}: {}".format(str(key).capitalize(), rounding(values, 4)))
        self.print_("Cumulative training time: {}".format(rounding(self.cumulative_times, 2)))
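
    # Usage sketch (hedged): a standalone run typically constructs the server and
    # calls `start` with a model and a list of clients; `conf`, `model`, and
    # `clients` below are assumed to be prepared by the caller.
    #
    # >>> server = BaseServer(conf, test_data=test_data)
    # >>> server.start(model, clients)  # runs rounds until should_terminate() is True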

    def stop(self):
        """Set the flag to indicate training should stop."""
        self.should_stop = True

    def pre_train(self):
        """Preprocessing before training."""
        pass

    def train(self):
        """Training process of federated learning."""
        self.print_("--- start training ---")

        self.selection(self.clients, self.conf.server.clients_per_round)
        self.grouping_for_distributed()
        self.compression()

        begin_train_time = time.time()

        self.distribution_to_train()
        self.aggregation()

        train_time = time.time() - begin_train_time
        self.print_("Server train time: {:.2f}s".format(train_time))
        self.track(metric.TRAIN_TIME, train_time)
        self.tracking_visualization({metric.TRAIN_TIME: train_time})

    def post_train(self):
        """Postprocessing after training."""
        pass

    def pre_test(self):
        """Preprocessing before testing."""
        pass

    def test(self):
        """Testing process of federated learning."""
        self.print_("--- start testing ---")

        test_begin_time = time.time()

        test_results = {metric.TEST_TIME: 0}
        if self.conf.test_mode == TEST_IN_SERVER:
            if self.is_primary_server():
                test_results = self.test_in_server(self.conf.device)
                test_results[metric.TEST_TIME] = time.time() - test_begin_time
                self.track_test_results(test_results)
                self.tracking_visualization(test_results)
            if self.conf.is_distributed:
                dist.barrier()
            return
        else:
            test_results = self.test_in_client()
            test_results[metric.TEST_TIME] = time.time() - test_begin_time
            self.track_test_results(test_results)
            self.tracking_visualization(test_results)

    def post_test(self):
        """Postprocessing after testing."""
        pass

    def should_terminate(self):
        """Check whether training should stop.

        Training stops under two conditions:
        1. Reaching the max number of training rounds.
        2. TODO: Accuracy higher than a certain amount.

        Returns:
            bool: A flag to indicate whether training should stop.
        """
        if self.should_stop or (self.conf.server.rounds and self.current_round + 1 >= self.conf.server.rounds):
            self.is_training = False
            return True
        return False
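
    # Worked example: with conf.server.rounds = 10, rounds 0 through 9 are executed;
    # after round 9, the check `current_round + 1 >= rounds` (9 + 1 >= 10) stops training.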

    def test_in_client(self):
        """Conduct testing in clients.

        Currently, it supports testing on the clients selected for training.
        TODO: Add options to select clients for testing.

        Returns:
            dict: Test metrics, {"test_metric": dict, "test_time": value}.
        """
        self.compression()
        self.distribution_to_test()
        return self.aggregation_test()

    def test_in_server(self, device=CPU):
        """Conduct testing in the server.

        Overwrite this method for a different training backend; the default is PyTorch.

        Args:
            device (str): The hardware device to conduct testing, either cpu or cuda devices.

        Returns:
            dict: Test metrics, {"test_metric": dict, "test_time": value}.
        """
        self.model.eval()
        self.model.to(device)
        test_loss = 0
        correct = 0
        loss_fn = torch.nn.CrossEntropyLoss().to(device)
        test_loader = self.test_data.loader(self.conf.server.batch_size, seed=self.conf.seed)
        with torch.no_grad():
            for batched_x, batched_y in test_loader:
                x = batched_x.to(device)
                y = batched_y.to(device)
                log_probs = self.model(x)
                loss = loss_fn(log_probs, y)
                _, y_pred = torch.max(log_probs, -1)
                correct += y_pred.eq(y.data.view_as(y_pred)).long().cpu().sum()
                test_loss += loss.item()
            test_data_size = self.test_data.size()
            test_loss /= len(test_loader)
            accuracy = 100.00 * correct / test_data_size

            test_results = {
                metric.TEST_METRIC: {"accuracy": float(accuracy), "loss": float(test_loss)}
            }
        return test_results
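
    # Override sketch (hedged): a subclass can extend the server-side evaluation,
    # e.g., to report extra metrics. `CustomizedServer` and the "extra" key are
    # illustrative only.
    #
    # >>> class CustomizedServer(BaseServer):
    # >>>     def test_in_server(self, device=CPU):
    # >>>         results = super().test_in_server(device)
    # >>>         results[metric.TEST_METRIC]["extra"] = 0.0  # compute a custom metric here
    # >>>         return results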

    def selection(self, clients, clients_per_round):
        """Select a fraction of the total clients for training.

        Two selection strategies are implemented: 1. random selection; 2. select the first K clients.

        Args:
            clients (list[:obj:`BaseClient`]|list[str]): Available clients.
            clients_per_round (int): Number of clients to participate in training each round.

        Returns:
            (list[:obj:`BaseClient`]|list[str]): The selected clients.
        """
        if clients_per_round > len(clients):
            logger.warning("Available clients for selection are fewer than the required clients for each round")

        clients_per_round = min(clients_per_round, len(clients))
        if self.conf.server.random_selection:
            np.random.seed(self.current_round)
            self.selected_clients = np.random.choice(clients, clients_per_round, replace=False)
        else:
            self.selected_clients = clients[:clients_per_round]

        return self.selected_clients
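
    # Note on reproducibility: seeding NumPy with the round number makes random
    # selection deterministic per round, so restarted runs re-select the same
    # clients. For example, for round 3 (hypothetical client ids):
    #
    # >>> np.random.seed(3)
    # >>> np.random.choice(["c1", "c2", "c3", "c4"], 2, replace=False)  # same pair on every run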

    def grouping_for_distributed(self):
        """Divide the selected clients into groups for distributed training.

        Each group of clients is assigned to conduct training on one GPU, so the number of
        groups equals the number of GPUs.

        When not in distributed training, all selected clients are in the same group.
        In distributed training, the selected clients are grouped with different strategies: greedy and random.
        """
        if self.conf.is_distributed:
            groups = grouping(self.selected_clients,
                              self.conf.distributed.world_size,
                              self.default_time,
                              self.conf.resource_heterogeneous.grouping_strategy,
                              self.current_round)
            # Assign a group for each rank to train with the current device.
            self.grouped_clients = groups[self.conf.distributed.rank]
            grouping_info = [(c.cid, c.round_time) for c in self.grouped_clients]
            logger.info("Grouping Result for rank {}: {}".format(self.conf.distributed.rank, grouping_info))
        else:
            self.grouped_clients = self.selected_clients

    def compression(self):
        """Model compression to reduce communication cost."""
        pass

    def distribution_to_train(self):
        """Distribute model and configurations to selected clients to train."""
        if self.is_remote:
            self.distribution_to_train_remotely()
        else:
            self.distribution_to_train_locally()

            # Adaptively update the training time of clients for greedy grouping.
            if self.conf.is_distributed and self.conf.resource_heterogeneous.grouping_strategy == GREEDY_GROUPING:
                self.profile_training_speed()
                self.update_default_time()

    def distribution_to_train_locally(self):
        """Conduct training sequentially for the selected clients in the group."""
        uploaded_models = {}
        uploaded_weights = {}
        uploaded_metrics = []
        for client in self.grouped_clients:
            # Update the client config before training.
            self.conf.client.task_id = self.conf.task_id
            self.conf.client.round_id = self.current_round

            uploaded_request = client.run_train(self.model, self.conf.client)
            uploaded_content = uploaded_request.content

            model = self.decompression(codec.unmarshal(uploaded_content.data))
            uploaded_models[client.cid] = model
            uploaded_weights[client.cid] = uploaded_content.data_size
            uploaded_metrics.append(metric.ClientMetric.from_proto(uploaded_content.metric))

        self.set_client_uploads_train(uploaded_models, uploaded_weights, uploaded_metrics)

    def distribution_to_train_remotely(self):
        """Distribute training requests to remote clients through multiple threads.

        The main thread waits for a signal to proceed. The signal can be triggered
        via notification, as in the example below.

        Example to trigger the signal:
            >>> with self.condition():
            >>>     self.notify_all()
        """
        start_time = time.time()
        should_track = self.tracker is not None and self.conf.client.track
        with concurrent.futures.ThreadPoolExecutor() as executor:
            for client in self.grouped_clients:
                request = client_pb.OperateRequest(
                    type=client_pb.OP_TYPE_TRAIN,
                    model=codec.marshal(self.model),
                    data_index=client.index,
                    config=client_pb.OperateConfig(
                        batch_size=self.conf.client.batch_size,
                        local_epoch=self.conf.client.local_epoch,
                        seed=self.conf.seed,
                        local_test=self.conf.client.local_test,
                        optimizer=client_pb.Optimizer(
                            type=self.conf.client.optimizer.type,
                            lr=self.conf.client.optimizer.lr,
                            momentum=self.conf.client.optimizer.momentum,
                        ),
                        task_id=self.conf.task_id,
                        round_id=self.current_round,
                        track=should_track,
                    ),
                )
                executor.submit(self._distribution_remotely, client.client_id, request)

        distribute_time = time.time() - start_time
        self.track(metric.TRAIN_DISTRIBUTE_TIME, distribute_time)
        logger.info("Distribute to clients, time: {}".format(distribute_time))
        with self._condition:
            self._condition.wait()

    def distribution_to_test(self):
        """Distribute the model to conduct testing on clients."""
        if self.is_remote:
            self.distribution_to_test_remotely()
        else:
            self.distribution_to_test_locally()

    def distribution_to_test_locally(self):
        """Conduct testing sequentially for the selected testing clients."""
        uploaded_performance = {}
        uploaded_data_sizes = []
        uploaded_metrics = []

        test_clients = self.get_test_clients()
        for client in test_clients:
            # Update the client config before testing.
            self.conf.client.task_id = self.conf.task_id
            self.conf.client.round_id = self.current_round

            uploaded_request = client.run_test(self.model, self.conf.client)
            uploaded_content = uploaded_request.content
            m = codec.unmarshal(uploaded_content.data)
            for key in m.metric.keys():
                if key in uploaded_performance:
                    uploaded_performance[key].append(m.metric[key])
                else:
                    uploaded_performance[key] = [m.metric[key]]
            uploaded_data_sizes.append(uploaded_content.data_size)
            uploaded_metrics.append(metric.ClientMetric.from_proto(uploaded_content.metric))

        self.set_client_uploads_test(uploaded_performance, uploaded_data_sizes, uploaded_metrics)

    def distribution_to_test_remotely(self):
        """Distribute testing requests to remote clients through multiple threads.

        The main thread waits for a signal to proceed. The signal can be triggered
        via notification, as in the example below.

        Example to trigger the signal:
            >>> with self.condition():
            >>>     self.notify_all()
        """
        start_time = time.time()
        should_track = self.tracker is not None and self.conf.client.track
        test_clients = self.get_test_clients()
        with concurrent.futures.ThreadPoolExecutor() as executor:
            for client in test_clients:
                request = client_pb.OperateRequest(
                    type=client_pb.OP_TYPE_TEST,
                    model=codec.marshal(self.model),
                    data_index=client.index,
                    config=client_pb.OperateConfig(
                        batch_size=self.conf.client.batch_size,
                        test_batch_size=self.conf.client.test_batch_size,
                        seed=self.conf.seed,
                        task_id=self.conf.task_id,
                        round_id=self.current_round,
                        track=should_track,
                    ),
                )
                executor.submit(self._distribution_remotely, client.client_id, request)

        distribute_time = time.time() - start_time
        self.track(metric.TEST_DISTRIBUTE_TIME, distribute_time)
        logger.info("Distribute to test clients, time: {}".format(distribute_time))
        with self._condition:
            self._condition.wait()

    def get_test_clients(self):
        """Get the clients to run testing.

        Returns:
            (list[:obj:`BaseClient`]|list[str]): Clients to test.
        """
        if self.conf.server.test_all:
            if self.conf.is_distributed:
                # Group and assign clients to different hardware devices to test.
                test_clients = grouping(self.clients,
                                        self.conf.distributed.world_size,
                                        default_time=self.default_time,
                                        strategy=self.conf.resource_heterogeneous.grouping_strategy)
                test_clients = test_clients[self.conf.distributed.rank]
            else:
                test_clients = self.clients
        else:
            # For the initial testing, if no clients are selected, test all clients.
            test_clients = self.grouped_clients if self.grouped_clients is not None else self.clients
        return test_clients

    def _distribution_remotely(self, cid, request):
        """Distribute request to the assigned client to conduct operations.

        Args:
            cid (str): Client id.
            request (:obj:`OperateRequest`): gRPC request of specific operations.
        """
        resp = self.client_stubs[cid].Operate(request)
        if resp.status.code != common_pb.SC_OK:
            logger.error("Failed to train/test in client {}, error: {}".format(cid, resp.status.message))
        else:
            logger.info("Distribute to train/test remotely successfully, client: {}".format(cid))

    def aggregation_test(self):
        """Aggregate testing results from clients.

        Returns:
            dict: Test metrics, format in {"test_metric": dict}.
        """
        test_results = self.client_uploads[METRIC]
        test_sizes = self.client_uploads[DATA_SIZE]

        avg_results = {}
        metric_keys = sorted(test_results.keys())
        if self.conf.test_method == "average":
            for key in metric_keys:
                avg_results[key] = self._mean_value(test_results[key])
        elif self.conf.test_method == "weighted":
            for key in metric_keys:
                avg_results[key] = self._weighted_value(test_results[key], test_sizes)
        else:
            raise ValueError("test_method not supported, please use average or weighted")

        test_results = {
            metric.TEST_METRIC: avg_results,
        }
        return test_results
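
    # Worked example (assuming `_weighted_value` computes a test-size-weighted mean):
    # accuracies [0.8, 0.9] with test sizes [100, 300] give
    # (0.8 * 100 + 0.9 * 300) / 400 = 0.875 for "weighted", versus 0.85 for "average".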

    def decompression(self, model):
        """Decompress the model uploaded from clients."""
        return model

    def aggregation(self):
        """Aggregate training updates from clients.

        Server aggregates trained models from clients via federated averaging.
        """
        uploaded_content = self.get_client_uploads()
        models = list(uploaded_content[MODEL].values())
        weights = list(uploaded_content[DATA_SIZE].values())

        model = self.aggregate(models, weights)
        self.set_model(model, load_dict=True)

    def aggregate(self, models, weights):
        """Aggregate models uploaded from clients via federated averaging.

        Overwrite this method for a different training backend; the default is for PyTorch.

        Args:
            models (list[nn.Module]): List of models.
            weights (list[float]): List of weights, corresponding to each model.
                Weights are the dataset sizes of clients by default.

        Returns:
            nn.Module: Aggregated model.
        """
        if self.conf.server.aggregation_strategy == EQUAL_AVERAGE:
            weights = [1 for _ in range(len(models))]

        fn_average = strategies.federated_averaging
        fn_sum = strategies.weighted_sum
        fn_reduce = reduce_models
        if self.conf.server.aggregation_content == AGGREGATION_CONTENT_PARAMS:
            fn_average = strategies.federated_averaging_only_params
            fn_sum = strategies.weighted_sum_only_params
            fn_reduce = reduce_models_only_params

        if self.conf.is_distributed:
            dist.barrier()
            model, sample_sum = fn_sum(models, weights)
            fn_reduce(model, torch.tensor(sample_sum).to(self.conf.device))
        else:
            model = fn_average(models, weights)
        return model
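
    # Federated averaging in brief: given client models with parameters θ_k and
    # weights n_k (client dataset sizes by default), the aggregate is
    # θ = Σ_k n_k · θ_k / Σ_k n_k, applied parameter-wise. With EQUAL_AVERAGE all
    # n_k are 1, reducing to a plain mean of the client models.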

    def is_training(self):
        """Check whether the server is in training or has stopped training.

        Note: the ``is_training`` attribute set in ``__init__`` shadows this method
        on instances, so ``server.is_training`` resolves to the attribute.

        Returns:
            bool: A flag to indicate whether the server is in training.
        """
        return self.is_training

    def set_model(self, model, load_dict=False):
        """Update the universal model in the server.

        Overwrite this method for a different training backend; the default is for PyTorch.

        Args:
            model (nn.Module): New model.
            load_dict (bool): A flag to indicate whether to load the state dict or copy the model.
        """
        if load_dict:
            self.model.load_state_dict(model.state_dict())
        else:
            self.model = copy.deepcopy(model)

    def save_model(self):
        """Save the model in the server.

        Overwrite this method for a different training backend; the default is PyTorch.
        """
        if self._do_every(self.conf.server.save_model_every, self.current_round, self.conf.server.rounds) and \
                self.is_primary_server():
            save_path = self.conf.server.save_model_path
            if save_path == "":
                save_path = os.path.join(os.getcwd(), "saved_models")
            os.makedirs(save_path, exist_ok=True)
            save_path = os.path.join(save_path,
                                     "{}_global_model_r_{}.pth".format(self.conf.task_id, self.current_round))
            torch.save(self.model.cpu().state_dict(), save_path)
            self.print_("Model saved at {}".format(save_path))

    def set_client_uploads_customize_content(self, content1, content2, name1="content1", name2="content2"):
        """Set customized content uploaded from clients.

        Args:
            content1 (dict): A collection of content.
            content2 (dict): A collection of content.
            name1 (str): Name of content1.
            name2 (str): Name of content2.
        """
        self.set_client_uploads(name1, content1)
        self.set_client_uploads(name2, content2)

    def set_client_uploads_train(self, models, weights, metrics=None):
        """Set training updates uploaded from clients.

        Args:
            models (dict): A collection of models.
            weights (dict): A collection of weights.
            metrics (dict): Client training metrics.
        """
        self.set_client_uploads(MODEL, models)
        self.set_client_uploads(DATA_SIZE, weights)
        if self._should_gather_metrics():
            metrics = self.gather_client_train_metrics()
        self.set_client_uploads(CLIENT_METRICS, metrics)

    def set_client_uploads_test(self, metrics, test_sizes, client_metrics=None):
        """Set testing results uploaded from clients.

        Args:
            metrics (dict[list]): Testing metrics of clients.
            test_sizes (list[float]): Test dataset sizes of clients.
            client_metrics (dict): Client testing metrics.
        """
        self.set_client_uploads(METRIC, metrics)
        self.set_client_uploads(DATA_SIZE, test_sizes)
        if self._should_gather_metrics() and CLIENT_METRICS in self.client_uploads:
            train_metrics = self.get_client_uploads()[CLIENT_METRICS]
            client_metrics = metric.ClientMetric.merge_train_to_test_metrics(train_metrics, client_metrics)
        self.set_client_uploads(CLIENT_METRICS, client_metrics)

    def set_client_uploads(self, key, value):
        """A general function to set uploaded content from clients.

        Args:
            key (str): Dictionary key.
            value (*): Uploaded content.
        """
        self.client_uploads[key] = value

    def get_client_uploads(self):
        """Get client uploaded contents.

        Returns:
            dict: A dictionary that contains client uploaded contents.
        """
        return self.client_uploads

    def print_(self, content):
        """Log the content only when the server is the primary server.

        Args:
            content (str): The content to log.
        """
        if self.is_primary_server():
            logger.info(content)

    def is_primary_server(self):
        """Check whether the current process is the primary server.

        In standalone or remote training, the server is primary.
        In distributed training, the server on rank 0 is primary.

        Returns:
            bool: A flag to indicate whether the current process is the primary server.
        """
        return not self.conf.is_distributed or self.conf.distributed.rank == 0

    # Functions for remote training

    def start_service(self):
        """Start the federated learning server gRPC service."""
        if self.is_remote:
            grpc_wrapper.start_service(grpc_wrapper.TYPE_SERVER, ServerService(self), self.local_port)
            logger.info("GRPC server started at :{}".format(self.local_port))

    def connect_remote_clients(self, clients):
        # TODO: This client should be consistent with client started separately.
        for client in clients:
            if client.client_id not in self.client_stubs:
                self.client_stubs[client.client_id] = grpc_wrapper.init_stub(grpc_wrapper.TYPE_CLIENT, client.address)
                logger.info("Successfully connected to gRPC client {}".format(client.address))

    def init_etcd(self, addresses):
        """Initialize etcd as the registry for client registration.

        Args:
            addresses (str): The etcd addresses, split by ",".
        """
        self._etcd = EtcdClient("server", addresses, "backends")
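
    # Usage sketch (hedged): the address string is a comma-separated list of
    # host:port endpoints; the endpoints below are illustrative.
    #
    # >>> server.init_etcd("127.0.0.1:2379,127.0.0.1:2380")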

    def start_remote_training(self, model, clients):
        """Start federated learning in the remote training mode.

        The server first establishes gRPC connections with clients that are not yet
        connected, then starts training.

        Args:
            model (nn.Module): The model to train.
            clients (list[str]): Client addresses.
        """
        self.connect_remote_clients(clients)
        self.start(model, clients)

    def track(self, metric_name, value):
        """Track a metric.

        Args:
            metric_name (str): Name of the metric of a round.
            value (str|int|float|bool|dict|list): Value of the metric.
        """
        if not self._should_track():
            return
        self.tracker.track_round(metric_name, value)

    def track_test_results(self, results):
        """Track test results collected from clients.

        Args:
            results (dict): Test metrics, format in {"test_metric": dict, "test_time": value}.
        """
        self.cumulative_times.append(time.time() - self._start_time)

        test_metrics = results[metric.TEST_METRIC]
        for key, value in test_metrics.items():
            if key in self.performance_metrics:
                self.performance_metrics[key].append(value)
            else:
                self.performance_metrics[key] = [value]

        for metric_name in results:
            self.track(metric_name, results[metric_name])

        test_metric_content = ''.join(
            [", Test {}: {:.4f}".format(key, value) for key, value in test_metrics.items()])
        self.print_('Test time: {:.2f}s'.format(results[metric.TEST_TIME]) + test_metric_content)

    def save_tracker(self):
        """Save metrics in the tracker to the database."""
        if self.tracker:
            self.track_communication_cost()
            if self.is_primary_server():
                self.tracker.save_round()
            # In distributed training, each server saves their clients separately.
            self.tracker.save_clients(self.client_uploads[CLIENT_METRICS])

    def track_communication_cost(self):
        """Track communication cost among the server and clients.

        Communication cost occurs in `training` and `testing`, with downlink and uplink costs.
        """
        train_upload_size = 0
        train_download_size = 0
        test_upload_size = 0
        test_download_size = 0
        for client_metric in self.client_uploads[CLIENT_METRICS]:
            if client_metric.round_id == self.current_round and client_metric.task_id == self.conf.task_id:
                train_upload_size += client_metric.train_upload_size
                train_download_size += client_metric.train_download_size
                test_upload_size += client_metric.test_upload_size
                test_download_size += client_metric.test_download_size
        if self.conf.is_distributed:
            train_upload_size = reduce_value(train_upload_size, self.conf.device).item()
            train_download_size = reduce_value(train_download_size, self.conf.device).item()
            test_upload_size = reduce_value(test_upload_size, self.conf.device).item()
            test_download_size = reduce_value(test_download_size, self.conf.device).item()
        self.tracker.track_round(metric.TRAIN_UPLOAD_SIZE, train_upload_size)
        self.tracker.track_round(metric.TRAIN_DOWNLOAD_SIZE, train_download_size)
        self.tracker.track_round(metric.TEST_UPLOAD_SIZE, test_upload_size)
        self.tracker.track_round(metric.TEST_DOWNLOAD_SIZE, test_download_size)

    def _should_track(self):
        """Check whether the server should track metrics.

        The server tracks metrics only when tracking is enabled and it is the primary server.

        Returns:
            bool: A flag to indicate whether the server should track metrics.
        """
        return self.tracker is not None and self.is_primary_server()

    def _should_gather_metrics(self):
        """Check whether the server should gather metrics from GPUs.

        Metrics are gathered only when testing all clients in `distributed` training.

        Testing all resets clients' training metrics; thus, the server needs to gather
        the train metrics to construct the full client metrics.

        Returns:
            bool: A flag to indicate whether the server should gather metrics.
        """
        return self.conf.is_distributed and self.conf.server.test_all and self.tracker

    def gather_client_train_metrics(self):
        """Gather client train metrics from other ranks in distributed training when testing all clients (test_all).

        When testing all clients, the train metrics may be overridden by the test metrics
        because clients may be placed on different GPUs in training and testing,
        leading to loss of train metrics. So we gather the train metrics and set them in the test metrics.

        TODO: gather is not progressing. Needs a fix.
        """
        world_size = self.conf.distributed.world_size
        device = self.conf.device
        uploads = self.get_client_uploads()
        client_id_list = []
        train_metric_list_dict = {}
        train_time_list = []
        train_upload_time_list = []
        train_upload_size_list = []
        train_download_size_list = []

        for m in uploads[CLIENT_METRICS]:
            # client_id_list += gather_value(m.client_id, world_size, device).tolist()
            for key in m.train_metric.keys():
                if key in train_metric_list_dict:
                    train_metric_list_dict[key].append(gather_value(m.train_metric[key], world_size, device))
                else:
                    train_metric_list_dict[key] = [gather_value(m.train_metric[key], world_size, device)]
            train_time_list += gather_value(m.train_time, world_size, device)
            train_upload_time_list += gather_value(m.train_upload_time, world_size, device)
            train_upload_size_list += gather_value(m.train_upload_size, world_size, device)
            train_download_size_list += gather_value(m.train_download_size, world_size, device)

        metrics = []
        # Note: the client id may not match its training stats because all_gather of strings is not supported.
        client_id_list = [c.cid for c in self.selected_clients]
        for i, client_id in enumerate(client_id_list):
            m = metric.ClientMetric(self.conf.task_id, self.current_round, client_id)
            m.add(metric.TRAIN_METRIC, {key: value[i] for key, value in train_metric_list_dict.items()})
            m.add(metric.TRAIN_TIME, train_time_list[i])
            m.add(metric.TRAIN_UPLOAD_TIME, train_upload_time_list[i])
            m.add(metric.TRAIN_UPLOAD_SIZE, train_upload_size_list[i])
            m.add(metric.TRAIN_DOWNLOAD_SIZE, train_download_size_list[i])
            metrics.append(m)
        return metrics

    # Functions for remote training.

    def condition(self):
        return self._condition

    def notify_all(self):
        self._condition.notify_all()

    # Functions for distributed training optimization.

    def profile_training_speed(self):
        """Manage profiling of client training speeds for distributed training optimization."""
        profile_required = []
        for client in self.selected_clients:
            if not client.profiled:
                profile_required.append(client)
        if len(profile_required) > 0:
            original = torch.FloatTensor([c.round_time for c in profile_required]).to(self.conf.device)
            time_update = torch.FloatTensor([c.train_time for c in profile_required]).to(self.conf.device)
            dist.barrier()
            dist.all_reduce(time_update)
            for i in range(len(profile_required)):
                old_round_time = original[i]
                current_round_time = time_update[i]
                if old_round_time == 0 or self._should_update_round_time(old_round_time, current_round_time):
                    profile_required[i].round_time = float(current_round_time)
                    profile_required[i].train_time = 0
                else:
                    profile_required[i].profiled = True

    def update_default_time(self):
        """Update the estimated default training time of clients using the actual training time of profiled clients."""
        default_momentum = self.conf.resource_heterogeneous.default_time_momentum
        current_round_average = np.mean([float(c.round_time) for c in self.selected_clients])
        self.default_time = default_momentum * current_round_average + self.default_time * (1 - default_momentum)
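
    # Worked example: with default_time_momentum = 0.2, default_time = 10s, and a
    # current-round average of 20s, the update gives 0.2 * 20 + 0.8 * 10 = 12s,
    # an exponential moving average that drifts toward recent round times.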

    def _should_update_round_time(self, old_round_time, new_round_time, threshold=0.3):
        """Check whether to assign a new estimated round time to the client or mark it as 'profiled'.

        Args:
            old_round_time (float): Previously estimated round time.
            new_round_time (float): Currently profiled round time.
            threshold (float): Tolerance threshold for the relative difference between old and new times.

        Returns:
            bool: A flag to indicate whether to update the round time or not.
        """
        if new_round_time < old_round_time:
            return ((old_round_time - new_round_time) / new_round_time) >= threshold
        else:
            return ((new_round_time - old_round_time) / old_round_time) >= threshold
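
    # Worked example with the default threshold of 0.3: old = 10s, new = 14s gives
    # (14 - 10) / 10 = 0.4 >= 0.3, so the estimate is updated; old = 10s, new = 12s
    # gives 0.2 < 0.3, so the client is marked as profiled instead.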

    def tracking_visualization(self, results):
        """Visualize the tracked metrics.

        Args:
            results (dict): Training and test metrics that need tracking.
        """
        pass

    def init_visualization(self):
        """Initialize the external visualization tool, e.g., wandb or tensorboard."""
        pass