"""Source code for coala.coordinator."""

import logging
import os
import random
import sys
import time
from os import path

import numpy as np
import torch
from omegaconf import OmegaConf

from coala.client.base import BaseClient
from coala.client.base_semi import SemiFLClient
from coala.datasets import TEST_IN_SERVER
from coala.datasets.data import construct_datasets, construct_datasets_semi
from coala.distributed import dist_init, get_device
from coala.models.model import load_model
from coala.server.base import BaseServer
from coala.server.base_semi import SemiFLServer
from coala.simulation.system_hetero import resource_hetero_simulation

logger = logging.getLogger(__name__)


class Coordinator(object):
    """Coordinator manages federated learning server and client.

    A single instance of coordinator is initialized for each federated learning task
    when the package is imported. It owns the configuration, the datasets, the model,
    the server, and the client instances, and wires them together in :meth:`init`.
    """

    def __init__(self):
        # Registration flags: set True when the user supplies a custom
        # model/dataset/server/client via the register_* methods, which
        # suppresses the corresponding built-in initialization below.
        self.registered_model = False
        self.registered_dataset = False
        self.registered_server = False
        self.registered_client = False
        # Federated datasets, populated by init_dataset() or register_dataset().
        self.train_data = None
        self.test_data = None
        self.val_data = None
        # Merged OmegaConf configuration (set in init_conf).
        self.conf = None
        # Model instance and, when registered/loaded as a class, its class.
        self.model = None
        self._model_class = None
        self.server = None
        self._server_class = None
        self.clients = None
        self._client_class = None
        self.tracker = None
        # Semi-supervised FL only: labeled (s_) and unlabeled (u_) splits.
        self.s_train_data = None
        self.u_train_data = None

    def init(self, conf, init_all=True):
        """Initialize coordinator

        Args:
            conf (omegaconf.dictconfig.DictConfig): Internal configurations for federated learning.
            init_all (bool): Whether initialize dataset, model, server, and client other than configuration.
        """
        self.init_conf(conf)

        # Seed before any stochastic initialization (data split, model weights).
        _set_random_seed(conf.seed)

        if init_all:
            self.init_dataset()

            self.init_model()

            self.init_server()

            self.init_clients()

    def run(self):
        """Run the coordinator and the federated learning process.
        Initialize `torch.distributed` if distributed training is configured.
        """
        start_time = time.time()

        if self.conf.is_distributed:
            dist_init(
                self.conf.distributed.backend,
                self.conf.distributed.init_method,
                self.conf.distributed.world_size,
                self.conf.distributed.rank,
                self.conf.distributed.local_rank,
            )
        # The server drives the whole federated training loop with the clients.
        self.server.start(self.model, self.clients)
        self.print_("Total training time {:.1f}s".format(time.time() - start_time))

    def init_conf(self, conf):
        """Initialize coordinator configuration.

        Derives `is_distributed` and `device` from the `gpu` count:
        0 GPUs -> CPU, 1 GPU -> device 0, >1 GPUs -> distributed training
        with the device resolved from the distributed rank.

        Args:
            conf (omegaconf.dictconfig.DictConfig): Configurations.
        """
        self.conf = conf
        self.conf.is_distributed = (self.conf.gpu > 1)
        if self.conf.gpu == 0:
            self.conf.device = "cpu"
        elif self.conf.gpu == 1:
            self.conf.device = 0
        else:
            self.conf.device = get_device(self.conf.gpu, self.conf.distributed.world_size,
                                          self.conf.distributed.local_rank)
        self.print_("Configurations: {}".format(self.conf))

    def init_dataset(self):
        """Initialize datasets. Use provided datasets if not registered."""
        if self.registered_dataset:
            return
        if self.conf.data.fl_type == "semi_supervised":
            # Semi-supervised FL: build separate labeled (s_) and unlabeled (u_)
            # training splits; clients train on the unlabeled split.
            self.s_train_data, self.u_train_data, self.test_data = construct_datasets_semi(self.conf.data.root,
                                                                                           self.conf.data.dataset,
                                                                                           self.conf.data.num_of_clients,
                                                                                           self.conf.data.split_type,
                                                                                           self.conf.data.min_size,
                                                                                           self.conf.data.class_per_client,
                                                                                           self.conf.data.data_amount,
                                                                                           self.conf.data.iid_fraction,
                                                                                           self.conf.data.user,
                                                                                           self.conf.data.train_test_split,
                                                                                           self.conf.data.weights,
                                                                                           self.conf.data.alpha,
                                                                                           self.conf.data.semi_scenario,
                                                                                           self.conf.data.num_labels_per_class)
            self.train_data = self.u_train_data
            # NOTE(review): these use bare print instead of self.print_ like the
            # branch below, so they log on every rank in distributed mode.
            print(f"Total labeled training data amount: {self.s_train_data.total_size()}")
            print(f"Total unlabeled training data amount: {self.u_train_data.total_size()}")
        else:
            self.train_data, self.test_data = construct_datasets(self.conf.data.root,
                                                                 self.conf.data.dataset,
                                                                 self.conf.data.num_of_clients,
                                                                 self.conf.data.split_type,
                                                                 self.conf.data.min_size,
                                                                 self.conf.data.class_per_client,
                                                                 self.conf.data.data_amount,
                                                                 self.conf.data.iid_fraction,
                                                                 self.conf.data.user,
                                                                 self.conf.data.train_test_split,
                                                                 self.conf.data.weights,
                                                                 self.conf.data.alpha)

            self.print_(f"Total training data amount: {self.train_data.total_size()}")
            self.print_(f"Total testing data amount: {self.test_data.total_size()}")

    def init_model(self):
        """Initialize model instance."""
        if not self.registered_model:
            self._model_class = load_model(self.conf.model)

        # model_class is None means model is registered as instance, no need initialization
        if self._model_class:
            self.model = self._model_class()

    def init_server(self):
        """Initialize a server instance.

        Picks SemiFLServer for semi-supervised FL, otherwise BaseServer,
        unless a custom server class was registered.
        """
        if not self.registered_server:
            self._server_class = BaseServer
            if self.conf.data.fl_type == "semi_supervised":
                self._server_class = SemiFLServer

        kwargs = {
            "is_remote": self.conf.is_remote,
            "local_port": self.conf.local_port
        }

        # In the 'label_in_server' scenario the server holds the labeled data.
        if self.conf.data.fl_type == "semi_supervised" and self.conf.data.semi_scenario == 'label_in_server':
            kwargs["s_train_data"] = self.s_train_data

        # Hand test/validation data to the server only when testing is
        # configured to happen server-side.
        if self.conf.test_mode == TEST_IN_SERVER:
            kwargs["test_data"] = self.test_data
            if self.val_data:
                kwargs["val_data"] = self.val_data

        self.server = self._server_class(self.conf, **kwargs)

    def init_clients(self):
        """Initialize client instances, each represents a federated learning client."""
        if not self.registered_client:
            self._client_class = BaseClient
            if self.conf.data.fl_type == "semi_supervised":
                self._client_class = SemiFLClient

        # Enforce system heterogeneity of clients.
        # Default: no artificial delay for any client.
        sleep_time = [0 for _ in self.train_data.users]
        if self.conf.resource_heterogeneous.simulate:
            sleep_time = resource_hetero_simulation(self.conf.resource_heterogeneous.fraction,
                                                    self.conf.resource_heterogeneous.hetero_type,
                                                    self.conf.resource_heterogeneous.sleep_group_num,
                                                    self.conf.resource_heterogeneous.level,
                                                    self.conf.resource_heterogeneous.total_time,
                                                    len(self.train_data.users))

        # Clients keep no test data when testing happens in the server.
        client_test_data = self.test_data
        if self.conf.test_mode == TEST_IN_SERVER:
            client_test_data = None

        if self.conf.data.fl_type == "semi_supervised":
            # Clients only receive labeled data in the 'label_in_client' scenario.
            labeled_train_data = self.s_train_data if self.conf.data.semi_scenario == 'label_in_client' else None
            self.clients = [self._client_class(u,
                                               self.conf.client,
                                               labeled_train_data,
                                               self.u_train_data,
                                               client_test_data,
                                               self.conf.device,
                                               **{"sleep_time": sleep_time[i]})
                            for i, u in enumerate(self.u_train_data.users)]
        elif self.conf.data.fl_type == "self_supervised":
            raise NotImplementedError
        else:
            self.clients = [self._client_class(u,
                                               self.conf.client,
                                               self.train_data,
                                               client_test_data,
                                               self.conf.device,
                                               **{"sleep_time": sleep_time[i]})
                            for i, u in enumerate(self.train_data.users)]

        self.print_("Clients in total: {}".format(len(self.clients)))

    def init_client(self):
        """Initialize client instance.

        Used for remote training, where a single client service is started.

        Returns:
            :obj:`BaseClient`: The initialized client instance.
        """
        if not self.registered_client:
            self._client_class = BaseClient

        # Get a random client if not specified
        # NOTE(review): truthiness check means index 0 also falls through to the
        # random choice; `self.conf.index is not None` may be intended — confirm.
        if self.conf.index:
            user = self.train_data.users[self.conf.index]
        else:
            user = random.choice(self.train_data.users)

        return self._client_class(user,
                                  self.conf.client,
                                  self.train_data,
                                  self.test_data,
                                  self.conf.device,
                                  is_remote=self.conf.is_remote,
                                  local_port=self.conf.local_port,
                                  server_addr=self.conf.server_addr,
                                  tracker_addr=self.conf.tracker_addr)

    def start_server(self, args):
        """Start a server service for remote training.

        Server controls the model and testing dataset if configured to test in server.

        Args:
            args (argparse.Namespace): Configurations passed in as arguments, it is merged with configurations.
        """
        if args:
            self.conf = OmegaConf.merge(self.conf, args.__dict__)

        # The server only needs the dataset when it performs testing itself.
        if self.conf.test_mode == TEST_IN_SERVER:
            self.init_dataset()

        self.init_model()

        self.init_server()

        self.server.start_service()

    def start_client(self, args):
        """Start a client service for remote training.

        Client controls training datasets.

        Args:
            args (argparse.Namespace): Configurations passed in as arguments, it is merged with configurations.
        """

        if args:
            self.conf = OmegaConf.merge(self.conf, args.__dict__)

        self.init_dataset()

        client = self.init_client()

        client.start_service()

    def register_dataset(self, train_data, test_data, val_data=None):
        """Register datasets.

        Datasets should inherit from :obj:`FederatedDataset`, e.g., :obj:`FederatedTensorDataset`.

        Args:
            train_data (:obj:`FederatedDataset`): Training dataset.
            test_data (:obj:`FederatedDataset`): Testing dataset.
            val_data (:obj:`FederatedDataset`): Validation dataset.
        """
        self.registered_dataset = True
        self.train_data = train_data
        self.test_data = test_data
        self.val_data = val_data

    def register_model(self, model):
        """Register customized model for federated learning.

        Args:
            model (nn.Module): PyTorch model, both class and instance are acceptable.
                Use model class when there is no specific arguments to initialize model.
        """
        self.registered_model = True
        # An instance is stored directly; a class is instantiated later in init_model.
        if not isinstance(model, type):
            self.model = model
        else:
            self._model_class = model

    def register_server(self, server):
        """Register a customized federated learning server.

        Args:
            server (:obj:`BaseServer`): Customized federated learning server.
        """
        self.registered_server = True
        self._server_class = server

    def register_client(self, client):
        """Register a customized federated learning client.

        Args:
            client (:obj:`BaseClient`): Customized federated learning client.
        """
        self.registered_client = True
        self._client_class = client

    def print_(self, content):
        """Log the content only when the server is primary server.

        Args:
            content (str): The content to log.
        """
        if self._is_primary_server():
            logger.info(content)

    def _is_primary_server(self):
        """Check whether current running server is the primary server.

        In standalone or remote training, the server is primary.
        In distributed training, the server on `rank0` is primary.
        """
        return not self.conf.is_distributed or self.conf.distributed.rank == 0


def _set_random_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)


# Module-level singleton coordinator shared by the functional API below
# (init, run, register_*, start_server, start_client).
_global_coord = Coordinator()


def init_conf(conf=None):
    """Initialize configuration for COALA. It overrides and supplements default configuration loaded from config.yaml
    with the provided configurations.

    Args:
        conf (dict): Configurations.

    Returns:
        omegaconf.dictconfig.DictConfig: Internal configurations managed by OmegaConf.
    """
    # Default configuration ships next to this module as config.yaml.
    package_dir = path.abspath(path.dirname(__file__))
    return load_config(path.join(package_dir, 'config.yaml'), conf)


def load_config(file, conf=None):
    """Load and merge configuration from file and input.

    Args:
        file (str): Filename of the configuration.
        conf (dict): Configurations merged on top of the file contents;
            keys in `conf` override those loaded from `file`.

    Returns:
        omegaconf.dictconfig.DictConfig: Internal configurations managed by OmegaConf.
    """
    config = OmegaConf.load(file)
    if conf is not None:
        config = OmegaConf.merge(config, conf)
    return config
def init_logger(log_level):
    """Initialize internal logger of COALA.

    Attaches two handlers to the root logger: a file handler writing to a
    timestamped log file under ``<cwd>/logs``, and a console handler on stdout.

    Args:
        log_level (int): Logger level, e.g., logging.INFO, logging.DEBUG.
            Falsy values (e.g., None) fall back to logging.INFO.
    """
    log_formatter = logging.Formatter("%(asctime)s [%(threadName)s] [%(levelname)-5.5s] %(message)s")
    root_logger = logging.getLogger()
    log_level = logging.INFO if not log_level else log_level
    root_logger.setLevel(log_level)

    # Create the log directory lazily on first use.
    file_path = os.path.join(os.getcwd(), "logs")
    if not os.path.exists(file_path):
        os.makedirs(file_path)
    file_path = path.join(file_path, "train" + time.strftime(".%m_%d_%H_%M_%S") + ".log")
    file_handler = logging.FileHandler(file_path)
    file_handler.setFormatter(log_formatter)
    root_logger.addHandler(file_handler)

    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setFormatter(log_formatter)
    root_logger.addHandler(console_handler)
def init(conf=None, init_all=True):
    """Initialize COALA.

    Merges the provided configuration with defaults, sets up logging and
    random seeds, then initializes the global coordinator.

    Args:
        conf (dict, optional): Configurations.
        init_all (bool, optional): Whether initialize dataset, model, server, and client other than configuration.
    """
    global _global_coord

    config = init_conf(conf)

    init_logger(config.tracking.log_level)

    _set_random_seed(config.seed)

    _global_coord.init(config, init_all)
def run():
    """Run federated learning process via the global coordinator."""
    global _global_coord

    _global_coord.run()
def init_dataset():
    """Initialize dataset, either using registered dataset or out-of-the-box datasets set in config."""
    global _global_coord

    _global_coord.init_dataset()
def init_model():
    """Initialize model, either using registered model or out-of-the-box model set in config.

    Returns:
        nn.Module: Model used in federated learning.
    """
    global _global_coord

    _global_coord.init_model()
    return _global_coord.model
def start_server(args=None):
    """Start federated learning server service for remote training.

    Args:
        args (argparse.Namespace): Configurations passed in as arguments.
    """
    global _global_coord

    _global_coord.start_server(args)
def start_client(args=None):
    """Start federated learning client service for remote training.

    Args:
        args (argparse.Namespace): Configurations passed in as arguments.
    """
    global _global_coord

    _global_coord.start_client(args)
def get_coordinator():
    """Get the global coordinator instance.

    Returns:
        :obj:`Coordinator`: global coordinator instance.
    """
    return _global_coord
def register_dataset(train_data, test_data, val_data=None):
    """Register datasets for federated learning training.

    Args:
        train_data (:obj:`FederatedDataset`): Training dataset.
        test_data (:obj:`FederatedDataset`): Testing dataset.
        val_data (:obj:`FederatedDataset`): Validation dataset.
    """
    global _global_coord

    _global_coord.register_dataset(train_data, test_data, val_data)
def register_model(model):
    """Register model for federated learning training.

    Args:
        model (nn.Module): PyTorch model, both class and instance are acceptable.
    """
    global _global_coord

    _global_coord.register_model(model)
def register_server(server):
    """Register federated learning server.

    Args:
        server (:obj:`BaseServer`): Customized federated learning server.
    """
    global _global_coord

    _global_coord.register_server(server)
def register_client(client):
    """Register federated learning client.

    Args:
        client (:obj:`BaseClient`): Customized federated learning client.
    """
    global _global_coord

    _global_coord.register_client(client)