def federated_averaging(models, weights):
    """Compute the weighted average of model parameters and persistent buffers.

    Works on each model's ``state_dict``, so persistent buffers such as
    BatchNorm running statistics are averaged along with the parameters.

    Args:
        models (list[nn.Module]): List of models to average.
        weights (list[float]): List of weights, one per model. Weights are
            the clients' dataset sizes by default. ``None``, an empty list,
            or weights summing to 0 fall back to equal weighting.

    Returns:
        nn.Module: Weighted averaged model, or ``None`` if ``models`` is empty.
    """
    if not models:
        return None
    if weights is None or len(weights) == 0 or sum(weights) == 0:
        weights = [1] * len(models)

    model, total_weights = weighted_sum(models, weights)
    # Mutate the returned state_dict in place so its ``_metadata`` (module
    # versions consulted by load_state_dict) is preserved.
    model_params = model.state_dict()
    with torch.no_grad():
        for name, summed in model_params.items():
            averaged = torch.div(summed, total_weights)
            if not (summed.is_floating_point() or summed.is_complex()):
                # True division promotes integer entries (e.g. BatchNorm's
                # ``num_batches_tracked``) to float; round to nearest before
                # casting back instead of letting load_state_dict truncate.
                averaged = averaged.round().to(summed.dtype)
            model_params[name] = averaged
    model.load_state_dict(model_params)
    return model
def federated_averaging_only_params(models, weights):
    """Compute the weighted average of model parameters only.

    Only ``parameters()`` are averaged; buffers are not modified here
    (``weighted_sum_only_params`` copies them from ``models[0]``).

    Args:
        models (list[nn.Module]): List of models to average.
        weights (list[float]): List of weights, one per model. Weights are
            the clients' dataset sizes by default. ``None``, an empty list,
            or weights summing to 0 fall back to equal weighting.

    Returns:
        nn.Module: Weighted averaged model, or ``None`` if ``models`` is empty.
    """
    if not models:
        return None
    if weights is None or len(weights) == 0 or sum(weights) == 0:
        weights = [1] * len(models)

    model, total_weights = weighted_sum_only_params(models, weights)
    with torch.no_grad():
        for param in model.parameters():
            # Write the quotient back into the existing storage; ``copy_``
            # also casts, so this works regardless of the parameter dtype.
            param.copy_(param / total_weights)
    return model
def weighted_sum(models, weights):
    """Compute the weighted sum of model parameters and persistent buffers.

    Works on each model's ``state_dict``, so persistent buffers such as
    BatchNorm running statistics are summed together with the parameters.
    The input models are not modified.

    Args:
        models (list[nn.Module]): List of models to sum. Every model's
            state_dict must contain the keys of ``models[0]``'s state_dict
            (a missing key raises ``KeyError``).
        weights (list[float]): List of weights, one per model. Weights are
            the clients' dataset sizes by default.

    Returns:
        tuple[nn.Module, float]: A copy of ``models[0]`` whose state holds the
        weighted SUM (not the average), and the sum of the original weights.
        ``(None, 0)`` if ``models`` or ``weights`` is empty.
    """
    if len(models) == 0 or len(weights) == 0:
        return None, 0
    weights_sum = sum(weights)
    if weights_sum == 0:
        # In multiple GPU scenario, this is for aggregation within a GPU.
        # The returned sum of weights is still 0 to ignore the weights from this GPU.
        weights = [1] * len(models)

    model = copy.deepcopy(models[0])
    # Fetch every state_dict once rather than once per entry. ``summed`` is
    # used only as a container (keeping its ``_metadata``): each value is
    # replaced by a freshly computed tensor, so models[0] is never mutated.
    summed = models[0].state_dict()
    others = [m.state_dict() for m in models[1:]]

    with torch.no_grad():
        for name, first in list(summed.items()):
            # Out-of-place arithmetic: an integer buffer (e.g. BatchNorm's
            # ``num_batches_tracked``) times a float weight is promoted to
            # float here, whereas an in-place ``*=`` would raise.
            acc = first * weights[0]
            for i, state in enumerate(others, start=1):
                acc = acc + state[name] * weights[i]
            if acc.dtype != first.dtype:
                if acc.is_floating_point() and not (
                    first.is_floating_point() or first.is_complex()
                ):
                    acc = acc.round()
                acc = acc.to(first.dtype)
            summed[name] = acc

    model.load_state_dict(summed)
    return model, weights_sum
def weighted_sum_only_params(models, weights):
    """Compute the weighted sum of model parameters only.

    Only ``named_parameters()`` are summed; buffers keep the values
    deep-copied from ``models[0]``. The input models are not modified.

    Args:
        models (list[nn.Module]): List of models to sum. Every model must
            expose the parameter names of ``models[0]`` (a missing name
            raises ``KeyError``).
        weights (list[float]): List of weights, one per model. Weights are
            the clients' dataset sizes by default.

    Returns:
        tuple[nn.Module, float]: A deep copy of ``models[0]`` whose
        parameters hold the weighted SUM (not the average), and the sum of
        the original weights. ``(None, 0)`` if ``models`` or ``weights`` is
        empty.
    """
    if len(models) == 0 or len(weights) == 0:
        return None, 0
    weights_sum = sum(weights)
    if weights_sum == 0:
        # In multiple GPU scenario, this is for aggregation within a GPU.
        # The returned sum of weights is still 0 to ignore the weights from this GPU.
        weights = [1] * len(models)

    model_sum = copy.deepcopy(models[0])
    # Build each model's name->parameter map once, not once per parameter.
    others = [dict(m.named_parameters()) for m in models[1:]]

    with torch.no_grad():
        for name, param in model_sum.named_parameters():
            acc = param * weights[0]
            for i, other in enumerate(others, start=1):
                acc = acc + other[name] * weights[i]
            # Write back in place so the module keeps its Parameter objects.
            param.copy_(acc)
    return model_sum, weights_sum