慕村225694
# You can use PyTorch's sparse tensor type for this:
import torch
from torch import nn


class SparseLinear(nn.Module):
    """Linear layer whose weight matrix is a sparse COO tensor with a fixed pattern.

    Computes ``y = x @ W + b``, where ``W`` has shape ``(in_features, out_features)``.
    Only the entries located at ``sparse_indices`` are stored (and trained).

    Args:
        in_features: size of the last dimension of the input.
        out_features: size of the last dimension of the output.
        sparse_indices: integer tensor (or nested list) of shape ``(2, nnz)``.
            Row 0 holds input indices in ``[0, in_features)``; row 1 holds
            output indices in ``[0, out_features)``.
    """

    def __init__(self, in_features, out_features, sparse_indices):
        super().__init__()
        sparse_indices = torch.as_tensor(sparse_indices, dtype=torch.long)
        values = torch.randn(sparse_indices.shape[1])
        # `torch.sparse.FloatTensor` is deprecated; `sparse_coo_tensor` is the
        # supported constructor. coalesce() sums duplicate indices and sorts them.
        weight = torch.sparse_coo_tensor(
            sparse_indices, values, (in_features, out_features)
        ).coalesce()
        self.weight = nn.Parameter(weight, requires_grad=True)
        self.bias = nn.Parameter(torch.randn(out_features), requires_grad=True)

    def forward(self, x):
        """Return ``x @ weight + bias`` for ``x`` of shape ``(..., in_features)``."""
        lead_shape = x.shape[:-1]
        x2d = x.reshape(-1, x.shape[-1])
        # torch.sparse.mm needs the *sparse* operand first, so compute
        # (x @ W)^T = W^T @ x^T and transpose back. (The original called the
        # non-existent `torch.sparse.admm` and multiplied W by x in the wrong order.)
        out = torch.sparse.mm(self.weight.t(), x2d.t()).t() + self.bias
        return out.reshape(*lead_shape, self.bias.shape[0])