import logging
import os
import random
import sys
import time
from os import path

import numpy as np
import torch
from omegaconf import OmegaConf

from coala.client.base import BaseClient
from coala.client.base_semi import SemiFLClient
from coala.datasets import TEST_IN_SERVER
from coala.datasets.data import construct_datasets, construct_datasets_semi
from coala.distributed import dist_init, get_device
from coala.models.model import load_model
from coala.server.base import BaseServer
from coala.server.base_semi import SemiFLServer
from coala.simulation.system_hetero import resource_hetero_simulation

logger = logging.getLogger(__name__)


class Coordinator(object):
    """Coordinator manages federated learning server and client.
    A single instance of coordinator is initialized for each federated learning task
    when the package is imported.
    """

    def __init__(self):
        # Registration flags: when True, the corresponding component was supplied
        # by the user via register_* and init_* must not overwrite it.
        self.registered_model = False
        self.registered_dataset = False
        self.registered_server = False
        self.registered_client = False
        # Federated datasets, set by init_dataset() or register_dataset().
        self.train_data = None
        self.test_data = None
        self.val_data = None
        # Merged OmegaConf configuration, set by init_conf().
        self.conf = None
        # Model instance and (optionally) the class used to build it.
        self.model = None
        self._model_class = None
        # Server/client instances and the classes used to construct them.
        self.server = None
        self._server_class = None
        self.clients = None
        self._client_class = None
        self.tracker = None
        # Semi-supervised FL splits: labeled (s_) and unlabeled (u_) training data.
        self.s_train_data = None
        self.u_train_data = None

    def init(self, conf, init_all=True):
        """Initialize coordinator

        Args:
            conf (omegaconf.dictconfig.DictConfig): Internal configurations for federated learning.
            init_all (bool): Whether initialize dataset, model, server, and client other than configuration.
        """
        self.init_conf(conf)
        # Seed random/numpy/torch so runs with the same conf.seed are reproducible.
        _set_random_seed(conf.seed)
        if init_all:
            self.init_dataset()
            self.init_model()
            self.init_server()
            self.init_clients()

    def run(self):
        """Run the coordinator and the federated learning process.
        Initialize `torch.distributed` if distributed training is configured.
        """
        start_time = time.time()
        if self.conf.is_distributed:
            dist_init(
                self.conf.distributed.backend,
                self.conf.distributed.init_method,
                self.conf.distributed.world_size,
                self.conf.distributed.rank,
                self.conf.distributed.local_rank,
            )
        # The server drives the whole FL process with the model and client pool.
        self.server.start(self.model, self.clients)
        self.print_("Total training time {:.1f}s".format(time.time() - start_time))

    def init_conf(self, conf):
        """Initialize coordinator configuration.

        Args:
            conf (omegaconf.dictconfig.DictConfig): Configurations.
        """
        self.conf = conf
        # Using more than one GPU implies distributed training.
        self.conf.is_distributed = (self.conf.gpu > 1)
        if self.conf.gpu == 0:
            self.conf.device = "cpu"
        elif self.conf.gpu == 1:
            # Single GPU: use CUDA device index 0.
            self.conf.device = 0
        else:
            # Multi-GPU: resolve the device assigned to this rank.
            self.conf.device = get_device(self.conf.gpu, self.conf.distributed.world_size, self.conf.distributed.local_rank)
        self.print_("Configurations: {}".format(self.conf))

    def init_dataset(self):
        """Initialize datasets. Use provided datasets if not registered."""
        if self.registered_dataset:
            return
        if self.conf.data.fl_type == "semi_supervised":
            # Semi-supervised FL splits data into labeled (s_) and unlabeled (u_) parts.
            self.s_train_data, self.u_train_data, self.test_data = construct_datasets_semi(
                self.conf.data.root,
                self.conf.data.dataset,
                self.conf.data.num_of_clients,
                self.conf.data.split_type,
                self.conf.data.min_size,
                self.conf.data.class_per_client,
                self.conf.data.data_amount,
                self.conf.data.iid_fraction,
                self.conf.data.user,
                self.conf.data.train_test_split,
                self.conf.data.weights,
                self.conf.data.alpha,
                self.conf.data.semi_scenario,
                self.conf.data.num_labels_per_class)
            # Clients train on the unlabeled split; expose it as the generic train_data.
            self.train_data = self.u_train_data
            # NOTE(review): plain print (vs self.print_) also emits on non-primary
            # distributed ranks — confirm whether that is intended.
            print(f"Total labeled training data amount: {self.s_train_data.total_size()}")
            print(f"Total unlabeled training data amount: {self.u_train_data.total_size()}")
        else:
            self.train_data, self.test_data = construct_datasets(
                self.conf.data.root,
                self.conf.data.dataset,
                self.conf.data.num_of_clients,
                self.conf.data.split_type,
                self.conf.data.min_size,
                self.conf.data.class_per_client,
                self.conf.data.data_amount,
                self.conf.data.iid_fraction,
                self.conf.data.user,
                self.conf.data.train_test_split,
                self.conf.data.weights,
                self.conf.data.alpha)
            self.print_(f"Total training data amount: {self.train_data.total_size()}")
        self.print_(f"Total testing data amount: {self.test_data.total_size()}")

    def init_model(self):
        """Initialize model instance."""
        if not self.registered_model:
            self._model_class = load_model(self.conf.model)

        # model_class is None means model is registered as instance, no need initialization
        if self._model_class:
            self.model = self._model_class()

    def init_server(self):
        """Initialize a server instance."""
        if not self.registered_server:
            self._server_class = BaseServer
            if self.conf.data.fl_type == "semi_supervised":
                self._server_class = SemiFLServer
        kwargs = {"is_remote": self.conf.is_remote, "local_port": self.conf.local_port}
        # In the label-in-server scenario, the server itself trains on the labeled split.
        if self.conf.data.fl_type == "semi_supervised" and self.conf.data.semi_scenario == 'label_in_server':
            kwargs["s_train_data"] = self.s_train_data
        # Server holds the test (and optional validation) data only when testing in server.
        if self.conf.test_mode == TEST_IN_SERVER:
            kwargs["test_data"] = self.test_data
            if self.val_data:
                kwargs["val_data"] = self.val_data
        self.server = self._server_class(self.conf, **kwargs)

    def init_clients(self):
        """Initialize client instances, each represents a federated learning client."""
        if not self.registered_client:
            self._client_class = BaseClient
            if self.conf.data.fl_type == "semi_supervised":
                self._client_class = SemiFLClient
        # Enforce system heterogeneity of clients.
        sleep_time = [0 for _ in self.train_data.users]
        if self.conf.resource_heterogeneous.simulate:
            sleep_time = resource_hetero_simulation(
                self.conf.resource_heterogeneous.fraction,
                self.conf.resource_heterogeneous.hetero_type,
                self.conf.resource_heterogeneous.sleep_group_num,
                self.conf.resource_heterogeneous.level,
                self.conf.resource_heterogeneous.total_time,
                len(self.train_data.users))
        # Clients do not keep test data when testing happens in the server.
        client_test_data = self.test_data
        if self.conf.test_mode == TEST_IN_SERVER:
            client_test_data = None
        if self.conf.data.fl_type == "semi_supervised":
            # Clients only see labeled data in the label-in-client scenario.
            labeled_train_data = self.s_train_data if self.conf.data.semi_scenario == 'label_in_client' else None
            self.clients = [self._client_class(u,
                                               self.conf.client,
                                               labeled_train_data,
                                               self.u_train_data,
                                               client_test_data,
                                               self.conf.device,
                                               **{"sleep_time": sleep_time[i]})
                            for i, u in enumerate(self.u_train_data.users)]
        elif self.conf.data.fl_type == "self_supervised":
            raise NotImplementedError
        else:
            self.clients = [self._client_class(u,
                                               self.conf.client,
                                               self.train_data,
                                               client_test_data,
                                               self.conf.device,
                                               **{"sleep_time": sleep_time[i]})
                            for i, u in enumerate(self.train_data.users)]
        self.print_("Clients in total: {}".format(len(self.clients)))

    def init_client(self):
        """Initialize client instance.

        Returns:
            :obj:`BaseClient`: The initialized client instance.
        """
        if not self.registered_client:
            self._client_class = BaseClient
        # Get a random client if not specified
        # NOTE(review): `if self.conf.index` treats index 0 as "not specified" and
        # falls through to a random client — confirm whether 0 should be a valid index.
        if self.conf.index:
            user = self.train_data.users[self.conf.index]
        else:
            user = random.choice(self.train_data.users)
        return self._client_class(user,
                                  self.conf.client,
                                  self.train_data,
                                  self.test_data,
                                  self.conf.device,
                                  is_remote=self.conf.is_remote,
                                  local_port=self.conf.local_port,
                                  server_addr=self.conf.server_addr,
                                  tracker_addr=self.conf.tracker_addr)

    def start_server(self, args):
        """Start a server service for remote training.
        Server controls the model and testing dataset if configured to test in server.

        Args:
            args (argparse.Namespace): Configurations passed in as arguments, it is merged with configurations.
        """
        if args:
            self.conf = OmegaConf.merge(self.conf, args.__dict__)
        # The server only needs datasets when it performs testing itself.
        if self.conf.test_mode == TEST_IN_SERVER:
            self.init_dataset()
        self.init_model()
        self.init_server()
        self.server.start_service()

    def start_client(self, args):
        """Start a client service for remote training.
        Client controls training datasets.

        Args:
            args (argparse.Namespace): Configurations passed in as arguments, it is merged with configurations.
        """
        if args:
            self.conf = OmegaConf.merge(self.conf, args.__dict__)
        self.init_dataset()
        client = self.init_client()
        client.start_service()

    def register_dataset(self, train_data, test_data, val_data=None):
        """Register datasets.

        Datasets should inherit from :obj:`FederatedDataset`, e.g., :obj:`FederatedTensorDataset`.

        Args:
            train_data (:obj:`FederatedDataset`): Training dataset.
            test_data (:obj:`FederatedDataset`): Testing dataset.
            val_data (:obj:`FederatedDataset`): Validation dataset.
        """
        self.registered_dataset = True
        self.train_data = train_data
        self.test_data = test_data
        self.val_data = val_data

    def register_model(self, model):
        """Register customized model for federated learning.

        Args:
            model (nn.Module): PyTorch model, both class and instance are acceptable.
                Use model class when there is no specific arguments to initialize model.
        """
        self.registered_model = True
        # An instance is stored directly; a class is kept for lazy construction in init_model().
        if not isinstance(model, type):
            self.model = model
        else:
            self._model_class = model

    def register_server(self, server):
        """Register a customized federated learning server.

        Args:
            server (:obj:`BaseServer`): Customized federated learning server.
        """
        self.registered_server = True
        self._server_class = server

    def register_client(self, client):
        """Register a customized federated learning client.

        Args:
            client (:obj:`BaseClient`): Customized federated learning client.
        """
        self.registered_client = True
        self._client_class = client

    def print_(self, content):
        """Log the content only when the server is primary server.

        Args:
            content (str): The content to log.
        """
        if self._is_primary_server():
            logger.info(content)

    def _is_primary_server(self):
        """Check whether current running server is the primary server.
        In standalone or remote training, the server is primary.
        In distributed training, the server on `rank0` is primary.
        """
        return not self.conf.is_distributed or self.conf.distributed.rank == 0


def _set_random_seed(seed):
    # Seed all random sources used by the framework for reproducibility.
    # NOTE(review): only the current CUDA device is seeded; multi-GPU runs may
    # want torch.cuda.manual_seed_all — confirm.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)


# Initialize the global coordinator object
_global_coord = Coordinator()


def init_conf(conf=None):
    """Initialize configuration for COALA. It overrides and supplements default
    configuration loaded from config.yaml with the provided configurations.

    Args:
        conf (dict): Configurations.

    Returns:
        omegaconf.dictconfig.DictConfig: Internal configurations managed by OmegaConf.
    """
    here = path.abspath(path.dirname(__file__))
    config_file = path.join(here, 'config.yaml')
    return load_config(config_file, conf)
def load_config(file, conf=None):
    """Load the base configuration from a file and merge in overrides.

    Args:
        file (str): filename of the configuration.
        conf (dict): Configurations.

    Returns:
        omegaconf.dictconfig.DictConfig: Internal configurations managed by OmegaConf.
    """
    base = OmegaConf.load(file)
    if conf is None:
        return base
    return OmegaConf.merge(base, conf)
def init_dataset():
    """Initialize dataset, either using registered dataset or out-of-the-box datasets set in config."""
    # Delegates to the module-level coordinator singleton.
    _global_coord.init_dataset()
def init_model():
    """Initialize model, either using registered model or out-of-the-box model set in config.

    Returns:
        nn.Module: Model used in federated learning.
    """
    # Delegates to the module-level coordinator singleton and exposes its model.
    _global_coord.init_model()
    return _global_coord.model
def start_server(args=None):
    """Start federated learning server service for remote training.

    Args:
        args (argparse.Namespace): Configurations passed in as arguments.
    """
    # Delegates to the module-level coordinator singleton.
    _global_coord.start_server(args)
def start_client(args=None):
    """Start federated learning client service for remote training.

    Args:
        args (argparse.Namespace): Configurations passed in as arguments.
    """
    # Delegates to the module-level coordinator singleton.
    _global_coord.start_client(args)
def get_coordinator():
    """Get the global coordinator instance.

    Returns:
        :obj:`Coordinator`: global coordinator instance.
    """
    return _global_coord
def register_dataset(train_data, test_data, val_data=None):
    """Register datasets for federated learning training.

    Args:
        train_data (:obj:`FederatedDataset`): Training dataset.
        test_data (:obj:`FederatedDataset`): Testing dataset.
        val_data (:obj:`FederatedDataset`): Validation dataset.
    """
    # Delegates to the module-level coordinator singleton.
    _global_coord.register_dataset(train_data, test_data, val_data)
def register_model(model):
    """Register model for federated learning training.

    Args:
        model (nn.Module): PyTorch model, both class and instance are acceptable.
    """
    # Delegates to the module-level coordinator singleton.
    _global_coord.register_model(model)