HUX布斯
# If your k values don't vary too much and you want to vectorize the code,
# you can first take the top-max(k) entries of every row, and then gather
# the desired results with a mask.
import torch


def variable_topk(elements, topk_list):
    """Return the top-k values and indices of each row, with a per-row k.

    Vectorized alternative to calling ``elements[i].topk(topk_list[i])`` in a
    Python loop: one ``topk`` with the largest requested k is taken for every
    row, then a boolean mask keeps only the first ``topk_list[i]`` results of
    row ``i``. The extra work is proportional to ``max(topk_list)``, so this
    pays off when the k values are of similar size.

    Args:
        elements: 2-D tensor; row ``i`` is searched along its last dimension.
        topk_list: sequence of non-negative ints, one per row of ``elements``.

    Returns:
        ``(vals_per_row, inds_per_row)``: two tuples holding one 1-D tensor per
        row. These are the selected values in descending order and their
        column indices. A row with k == 0 yields empty tensors.

    Raises:
        ValueError: if ``elements`` is not 2-D, if ``topk_list`` does not have
            one entry per row, or if any k is negative.
        RuntimeError: raised by ``torch.topk`` when a k exceeds the row length.
    """
    if elements.dim() != 2:
        raise ValueError(f"elements must be 2-D, got shape {tuple(elements.shape)}")
    ks = [int(k) for k in topk_list]
    if len(ks) != elements.shape[0]:
        raise ValueError(
            f"topk_list has {len(ks)} entries but elements has "
            f"{elements.shape[0]} rows"
        )
    if any(k < 0 for k in ks):
        raise ValueError("topk_list entries must be non-negative")
    if not ks:
        # No rows: max() below would fail on an empty sequence.
        return (), ()

    max_k = max(ks)
    topk_vals, topk_inds = elements.topk(max_k, dim=-1)

    # mask[i, j] is True for the first ks[i] columns of row i. The mask is
    # built on the same device as `elements` so it also works for GPU tensors.
    k_col = torch.as_tensor(ks, device=elements.device)[:, None]
    mask = torch.arange(max_k, device=elements.device)[None, :] < k_col

    # Boolean indexing flattens in row-major order, so the results of each row
    # are contiguous and can be split back apart using the per-row counts.
    return topk_vals[mask].split(ks), topk_inds[mask].split(ks)


def _demo():
    """Compare the per-row loop from the question with ``variable_topk``."""
    elements = torch.rand(5, 10)
    topk_list = [2, 3, 1, 2, 0]  # top-2 for row 0, top-3 for row 1, top-1 for row 2, ...

    # Code from the question: one topk call per row.
    loop_results = [row.topk(k) for row, k in zip(elements, topk_list)]
    print(loop_results)

    vals_per_row, inds_per_row = variable_topk(elements, topk_list)

    # Flat view: the row each selected entry belongs to, its column, its value.
    rows = torch.repeat_interleave(
        torch.arange(len(topk_list)), torch.tensor(topk_list)
    )
    print("-" * 10)
    print("rows", rows)
    print("inds", torch.cat(inds_per_row))
    print("vals", torch.cat(vals_per_row))

    # Or split per row.
    print("-" * 10)
    print("vals_per_row", vals_per_row)
    print("inds_per_row", inds_per_row)

    # Or zip (a Python-level loop, but a cheap one).
    print("-" * 10)
    print("zipped results", list(zip(vals_per_row, inds_per_row)))


if __name__ == "__main__":
    _demo()

# Sample output (elements is random, so the values differ on every run):
# [torch.return_types.topk(values=tensor([0.8148, 0.7443]), indices=tensor([8, 4])),
#  torch.return_types.topk(values=tensor([0.7529, 0.7352, 0.6354]), indices=tensor([8, 1, 9])),
#  torch.return_types.topk(values=tensor([0.8792]), indices=tensor([7])),
#  torch.return_types.topk(values=tensor([0.9626, 0.8728]), indices=tensor([6, 2])),
#  torch.return_types.topk(values=tensor([]), indices=tensor([], dtype=torch.int64))]
# ----------
# rows tensor([0, 0, 1, 1, 1, 2, 3, 3])
# inds tensor([8, 4, 8, 1, 9, 7, 6, 2])
# vals tensor([0.8148, 0.7443, 0.7529, 0.7352, 0.6354, 0.8792, 0.9626, 0.8728])
# ----------
# vals_per_row (tensor([0.8148, 0.7443]), tensor([0.7529, 0.7352, 0.6354]), tensor([0.8792]),
#               tensor([0.9626, 0.8728]), tensor([]))
# inds_per_row (tensor([8, 4]), tensor([8, 1, 9]), tensor([7]), tensor([6, 2]),
#               tensor([], dtype=torch.int64))
# ----------
# zipped results [(tensor([0.8148, 0.7443]), tensor([8, 4])),
#                 (tensor([0.7529, 0.7352, 0.6354]), tensor([8, 1, 9])),
#                 (tensor([0.8792]), tensor([7])),
#                 (tensor([0.9626, 0.8728]), tensor([6, 2])),
#                 (tensor([]), tensor([], dtype=torch.int64))]