import argparse
import copy
import logging
import time

import torch

from coala.client.service import ClientService
from coala.communication import grpc_wrapper
from coala.distributed.distributed import CPU
from coala.pb import common_pb2 as common_pb
from coala.pb import server_service_pb2 as server_pb
from coala.protocol import codec
from coala.tracking import metric
from coala.tracking.client import init_tracking
from coala.tracking.evaluation import bit_to_megabyte

logger = logging.getLogger(__name__)


def _parse_bool(value):
    """Parse a command-line value into a boolean.

    ``argparse`` with ``type=bool`` treats any non-empty string — including
    the string ``"False"`` — as True, so ``--is-remote False`` used to enable
    remote mode. This converter interprets the common textual forms instead.

    Args:
        value (str|bool): Raw command-line value.

    Returns:
        bool: True for "true"/"1"/"yes"/"y"/"t" (case-insensitive), else False.
    """
    if isinstance(value, bool):
        return value
    return value.strip().lower() in ("true", "1", "yes", "y", "t")


def create_argument_parser():
    """Create argument parser with arguments/configurations for starting remote client service.

    Returns:
        argparse.ArgumentParser: Parser with client service arguments.
    """
    parser = argparse.ArgumentParser(description='Federated Client')
    parser.add_argument('--local-port',
                        type=int,
                        default=23000,
                        help='Listen port of the client')
    parser.add_argument('--server-addr',
                        type=str,
                        default="localhost:22999",
                        help='Address of server in [IP]:[PORT] format')
    parser.add_argument('--tracker-addr',
                        type=str,
                        default="localhost:12666",
                        help='Address of tracking service in [IP]:[PORT] format')
    # Bug fix: type=bool turned any non-empty string (e.g. "False") into True.
    parser.add_argument('--is-remote',
                        type=_parse_bool,
                        default=False,
                        help='Whether start as a remote client.')
    return parser
class BaseClient(object):
    """Default implementation of a federated learning client.

    Args:
        cid (str): Client id.
        conf (omegaconf.dictconfig.DictConfig): Client configurations.
        train_data (:obj:`FederatedDataset`): Training dataset.
        test_data (:obj:`FederatedDataset`): Test dataset.
        device (str): Hardware device for training, cpu or cuda devices.
        sleep_time (float): Duration of on hold after training to simulate stragglers.
        is_remote (bool): Whether start remote training.
        local_port (int): Port of remote client service.
        server_addr (str): Remote server service grpc address.
        tracker_addr (str): Remote tracking service grpc address.

    Override the class and functions to implement customized client.

    Example:
        >>> from coala.client import BaseClient
        >>> class CustomizedClient(BaseClient):
        >>>     def __init__(self, cid, conf, train_data, test_data, device, **kwargs):
        >>>         super(CustomizedClient, self).__init__(cid, conf, train_data, test_data, device, **kwargs)
        >>>         pass  # more initialization of attributes.
        >>>
        >>>     def train(self, conf, device=CPU):
        >>>         # Implement customized client training method, which overwrites the default training method.
        >>>         pass
    """

    def __init__(self,
                 cid,
                 conf,
                 train_data,
                 test_data,
                 device,
                 sleep_time=0,
                 is_remote=False,
                 local_port=23000,
                 server_addr="localhost:22999",
                 tracker_addr="localhost:12666"):
        # Identity and configuration.
        self.cid = cid
        self.conf = conf
        # Datasets and their lazily-created loaders.
        self.train_data = train_data
        self.train_loader = None
        self.test_data = test_data
        self.test_loader = None
        self.device = device
        # Timing statistics for the most recent round.
        self.round_time = 0
        self.train_time = 0
        self.test_time = 0
        # Training/testing metrics.
        self.train_accuracy = []
        self.train_loss = []
        self.test_accuracy = 0
        self.test_loss = 0
        self.test_metric = {}
        # Size of this client's training split; 0 when no training data given.
        self._datasize = self.train_data.size(self.cid) if self.train_data else 0
        self.profiled = False
        self._sleep_time = sleep_time
        self.model = None
        self._upload_holder = server_pb.UploadContent()
        # Remote-operation settings.
        self.is_remote = is_remote
        self.local_port = local_port
        self._server_addr = server_addr
        self._tracker_addr = tracker_addr
        self._server_stub = None
        self._tracker = None
        # Whether the current round is a training round (vs testing).
        self._is_train = True
        if conf.track:
            self._tracker = init_tracking(init_store=False)
def run_train(self, model, conf):
    """Conduct one round of training on this client.

    Unifies the interface for both local and remote operations: download the
    global model, run the train pipeline hooks in order, and upload results.

    Args:
        model (nn.Module): Model to train.
        conf (omegaconf.dictconfig.DictConfig): Client configurations.

    Returns:
        :obj:`UploadRequest`: Training contents.
    """
    self.conf = conf
    if conf.track:
        self._tracker.set_client_context(conf.task_id, conf.round_id, self.cid)
    self._is_train = True
    # Receive the global model and record its download size.
    self.set_model(model)
    self.track(metric.TRAIN_DOWNLOAD_SIZE, self.calculate_model_size(model))
    self.decompression()
    # Training pipeline: pre-hook, training, post-hook.
    self.pre_train()
    self.train(conf, self.device)
    self.post_train()
    self.track(metric.TRAIN_METRIC, {"accuracy": self.train_accuracy, "loss": self.train_loss})
    self.track(metric.TRAIN_TIME, self.train_time)
    if conf.local_test:
        self.test_local()
    # Prepare the local model for upload.
    self.compression()
    self.track(metric.TRAIN_UPLOAD_SIZE, self.calculate_model_size(self.model))
    self.encryption()
    return self.upload()
def run_test(self, model, conf):
    """Conduct one round of testing on this client.

    Unifies the interface for both local and remote operations.

    Args:
        model (nn.Module): Model to test.
        conf (omegaconf.dictconfig.DictConfig): Client configurations.

    Returns:
        :obj:`UploadRequest`: Testing contents.
    """
    self.conf = conf
    if conf.track:
        # Only reset the client tracker when the previous round was not training.
        reset = not self._is_train
        self._tracker.set_client_context(conf.task_id, conf.round_id, self.cid, reset_client=reset)
    self._is_train = False
    self.set_model(model)
    self.track(metric.TEST_DOWNLOAD_SIZE, self.calculate_model_size(model))
    self.decompression()
    # Testing pipeline: pre-hook, testing, post-hook.
    self.pre_test()
    self.test(conf, self.device)
    self.post_test()
    self.track(metric.TEST_METRIC, {"accuracy": float(self.test_accuracy), "loss": float(self.test_loss)})
    self.track(metric.TEST_TIME, self.test_time)
    return self.upload()
def set_model(self, model):
    """Adopt the given global model as the client model.

    Should be overwritten for a different training backend; the default is
    PyTorch. The distributed model is copied, never aliased, so later local
    updates cannot mutate the server's instance.

    Args:
        model (options: nn.Module, tf.keras.Model, ...): Global model distributed from the server.
    """
    if not self.model:
        # First round: keep a deep copy so the server's object is untouched.
        self.model = copy.deepcopy(model)
    else:
        # Reuse the existing local module, just refresh its weights.
        self.model.load_state_dict(model.state_dict())
def decompression(self):
    """Decompress the downloaded model.

    No-op by default; implement when the server compresses models before
    distribution.
    """
def pre_train(self):
    """Hook executed immediately before local training; no-op by default."""
def train(self, conf, device=CPU):
    """Execute client training for ``conf.local_epoch`` epochs.

    Records per-epoch average loss in ``self.train_loss`` and the wall-clock
    duration in ``self.train_time``.

    Args:
        conf (omegaconf.dictconfig.DictConfig): Client configurations.
        device (str): Hardware device for training, cpu or cuda devices.
    """
    started = time.time()
    loss_fn, optimizer = self.pretrain_setup(conf, device)
    self.train_loss = []
    for epoch in range(conf.local_epoch):
        epoch_losses = []
        for batched_x, batched_y in self.train_loader:
            x, y = batched_x.to(device), batched_y.to(device)
            optimizer.zero_grad()
            out = self.model(x)
            loss = loss_fn(out, y)
            loss.backward()
            optimizer.step()
            epoch_losses.append(loss.item())
        mean_loss = sum(epoch_losses) / len(epoch_losses)
        self.train_loss.append(float(mean_loss))
        logger.debug("Client {}, local epoch: {}, loss: {}".format(self.cid, epoch, mean_loss))
    self.train_time = time.time() - started
    logger.debug("Client {}, Train Time: {}".format(self.cid, self.train_time))
def post_train(self):
    """Hook executed immediately after local training; no-op by default."""
def pretrain_setup(self, conf, device):
    """Prepare the model, loss function and optimizer before training.

    Also simulates straggler delay and lazily creates the training loader.

    Args:
        conf (omegaconf.dictconfig.DictConfig): Client configurations.
        device (str): Hardware device for training.

    Returns:
        tuple: ``(loss_fn, optimizer)`` ready for the training loop.
    """
    self.simulate_straggler()
    self.model.train()
    self.model.to(device)
    criterion = self.load_loss_fn(conf)
    opt = self.load_optimizer(conf)
    if self.train_loader is None:
        # Create the loader once and reuse it across rounds.
        self.train_loader = self.load_loader(conf)
    return criterion, opt
def load_optimizer(self, conf):
    """Load the training optimizer. Implemented Adam and SGD.

    Args:
        conf (omegaconf.dictconfig.DictConfig): Client configurations
            with an ``optimizer`` section (type, lr, momentum, weight_decay).

    Returns:
        torch.optim.Optimizer: Adam when ``optimizer.type == "Adam"``,
        otherwise SGD.
    """
    opt_conf = conf.optimizer
    if opt_conf.type == "Adam":
        return torch.optim.Adam(self.model.parameters(), lr=opt_conf.lr)
    # Any other type falls back to SGD.
    return torch.optim.SGD(self.model.parameters(),
                           lr=opt_conf.lr,
                           momentum=opt_conf.momentum,
                           weight_decay=opt_conf.weight_decay)
def load_loader(self, conf):
    """Build the shuffled training data loader for this client.

    Args:
        conf (omegaconf.dictconfig.DictConfig): Client configurations.

    Returns:
        torch.utils.data.DataLoader: Data loader.
    """
    return self.train_data.loader(conf.batch_size, self.cid, shuffle=True, seed=conf.seed)
[docs]deftest_local(self):"""Test client local model after training."""pass
def pre_test(self):
    """Hook executed immediately before testing; no-op by default."""
def test(self, conf, device=CPU):
    """Execute client testing and record loss/accuracy metrics.

    Args:
        conf (omegaconf.dictconfig.DictConfig): Client configurations.
        device (str): Hardware device for training, cpu or cuda devices.
    """
    begin_test_time = time.time()
    self.model.eval()
    self.model.to(device)
    loss_fn = self.load_loss_fn(conf)
    if self.test_loader is None:
        # Create the (unshuffled) test loader once and reuse it.
        self.test_loader = self.test_data.loader(conf.test_batch_size, self.cid, shuffle=False, seed=conf.seed)
    # TODO: make evaluation metrics a separate package and apply it here.
    self.test_loss = 0
    correct = 0
    with torch.no_grad():
        for batched_x, batched_y in self.test_loader:
            x = batched_x.to(device)
            y = batched_y.to(device)
            log_probs = self.model(x)
            loss = loss_fn(log_probs, y)
            # Predicted class is the argmax over the last dimension.
            _, y_pred = torch.max(log_probs, -1)
            correct += y_pred.eq(y.data.view_as(y_pred)).long().cpu().sum()
            self.test_loss += loss.item()
    test_size = self.test_data.size(self.cid)
    # Average loss over batches; accuracy as a percentage over samples.
    self.test_loss /= len(self.test_loader)
    self.test_accuracy = 100.0 * float(correct) / test_size
    logger.debug('Client {}, testing -- Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)'.format(
        self.cid, self.test_loss, correct, test_size, self.test_accuracy))
    self.test_time = time.time() - begin_test_time
    self.test_metric = {"accuracy": self.test_accuracy, "loss": self.test_loss}
    # Move the model back to CPU to free accelerator memory.
    self.model = self.model.cpu()
def post_test(self):
    """Hook executed immediately after testing; no-op by default."""
def encryption(self):
    """Encrypt the client local model.

    No-op by default.
    """
    # TODO: encryption of model, remember to track encrypted model instead of compressed one after implementation.
def compression(self):
    """Compress the client local model after training and before uploading
    to the server; no-op by default.
    """
def upload(self):
    """Upload the client's results to the server.

    Returns:
        :obj:`UploadRequest`: The upload request defined in protobuf to unify
        local and remote operations. Only returned for local training — remote
        training uploads through a gRPC request and returns ``None``.
    """
    request = self.construct_upload_request()
    if self.is_remote:
        self.upload_remotely(request)
        self.post_upload()
        return None
    self.post_upload()
    return request
def post_upload(self):
    """Hook executed after uploading training/testing results; no-op by default."""
def construct_upload_request(self):
    """Construct client upload request for training updates and testing results.

    Returns:
        :obj:`UploadRequest`: The upload request defined in protobuf to unify
        local and remote operations.
    """
    # Default payload: testing performance metrics.
    data = codec.marshal(server_pb.Performance(metric=self.test_metric))
    typ = common_pb.DATA_TYPE_PERFORMANCE
    try:
        if self._is_train:
            # Training round: ship the (deep-copied) model parameters instead.
            data = codec.marshal(copy.deepcopy(self.model))
            typ = common_pb.DATA_TYPE_PARAMS
            data_size = self._datasize
        else:
            data_size = 1 if not self.test_data else self.test_data.size(self.cid)
    except KeyError:
        # When the datasize cannot be obtained from dataset, default to use equal aggregate
        data_size = 1
    m = self._tracker.get_client_metric().to_proto() if self._tracker else common_pb.ClientMetric()
    return server_pb.UploadRequest(
        task_id=self.conf.task_id,
        round_id=self.conf.round_id,
        client_id=self.cid,
        content=server_pb.UploadContent(
            data=data,
            type=typ,
            data_size=data_size,
            metric=m,
        ),
    )
def upload_remotely(self, request):
    """Send upload request to remote server via gRPC.

    Args:
        request (:obj:`UploadRequest`): Upload request.
    """
    start_time = time.time()
    self.connect_to_server()
    resp = self._server_stub.Upload(request)
    upload_time = time.time() - start_time
    # Track the upload time under the metric matching the round type.
    m = metric.TRAIN_UPLOAD_TIME if self._is_train else metric.TEST_UPLOAD_TIME
    self.track(m, upload_time)
    logger.info("client upload time: {}s".format(upload_time))
    if resp.status.code == common_pb.SC_OK:
        logger.info("Uploaded remotely to the server successfully\n")
    else:
        logger.error("Failed to upload, code: {}, message: {}\n".format(resp.status.code, resp.status.message))
def connect_to_server(self):
    """Establish a gRPC connection between the client and the server.

    Does nothing for local clients or when a stub already exists.
    """
    if not self.is_remote or self._server_stub is not None:
        return
    self._server_stub = grpc_wrapper.init_stub(grpc_wrapper.TYPE_SERVER, self._server_addr)
    logger.info("Successfully connected to gRPC server {}".format(self._server_addr))
def operate(self, model, conf, index, is_train=True):
    """A wrapper over operations (training/testing) on clients.

    Args:
        model (nn.Module): Model for operations.
        conf (omegaconf.dictconfig.DictConfig): Client configurations.
        index (int): Client index in the client list, for retrieving data. TODO: improvement.
        is_train (bool): The flag to indicate whether the operation is training, otherwise testing.
    """
    log = logging.getLogger(__name__)
    try:
        # Load the data index depending on server request
        self.cid = self.train_data.users[index]
    except IndexError:
        log.error("Data index exceed the available data, abort training")
        return
    if self.conf.track and self._tracker is None:
        self._tracker = init_tracking(init_store=False)
    if is_train:
        log.info("Train on data index {}, client: {}".format(index, self.cid))
        self.run_train(model, conf)
    else:
        log.info("Test on data index {}, client: {}".format(index, self.cid))
        self.run_test(model, conf)
# Functions for tracking.
def track(self, metric_name, value):
    """Track a metric through the client tracker, if tracking is enabled.

    Args:
        metric_name (str): The name of the metric.
        value (str|int|float|bool|dict|list): The value of the metric.
    """
    if not self.conf.track or self._tracker is None:
        logging.getLogger(__name__).debug("Tracker not available, Tracking not supported")
        return
    self._tracker.track_client(metric_name, value)
def save_metrics(self):
    """Save client metrics to database, when a tracker is available."""
    # TODO: not tested
    if self._tracker is None:
        logging.getLogger(__name__).debug("Tracker not available, no saving")
        return
    self._tracker.save_client()
# Functions for simulation.
def simulate_straggler(self):
    """Simulate the straggler effect of system heterogeneity by sleeping
    for the configured duration before training starts.
    """
    if self._sleep_time > 0:
        time.sleep(self._sleep_time)
def calculate_model_size(self, model, param_size=32):
    """Calculate the model parameter sizes, including non-trainable parameters.

    Should be overwritten for a different training backend.

    Args:
        model (options: nn.Module, tf.keras.Model, ...): A model.
        param_size (int): The size of a parameter in bits, default using float32.

    Returns:
        float: The model size in MB.
    """
    # Use `if p.requires_grad` inside the sum for only trainable parameters.
    total_params = sum(p.numel() for p in model.parameters())
    return bit_to_megabyte(total_params * param_size)