猿问

我无法理解这行代码是做什么的?

这部分课程我不明白这段代码做了什么:


for file in os.listdir(path):

    if(os.path.isfile(os.path.join(path,file)) and select in file): 

        temp = scipy.io.loadmat(os.path.join(path,file))

        temp = {k:v for k, v in temp.items() if k[0] != '_'}

        for i  in range(len(temp[patch_type+"_patches"])):

            self.tensors.append(temp[patch_type+"_patches"][i])

            self.labels.append(temp[patch_type+"_labels"][0][i])


self.tensors = np.array(self.tensors)

self.labels = np.array(self.labels)

尤其是这一行:


temp = {k:v for k, v in temp.items() if k[0] != '_'}

全班如下:


class Datasets(Dataset):

    def __init__(self,path,train,transform=None):

        if(train):

            select ="Training"

            patch_type = "train"

        else:

            select = "Testing"

            patch_type = "testing"


        self.tensors = []

        self.labels = []

        self.transform = transform



        for file in os.listdir(path):

            if(os.path.isfile(os.path.join(path,file)) and select in file): 


                temp = scipy.io.loadmat(os.path.join(path,file))

                temp = {k:v for k, v in temp.items() if k[0] != '_'}

                for i  in range(len(temp[patch_type+"_patches"])):

                    self.tensors.append(temp[patch_type+"_patches"][i])

                    self.labels.append(temp[patch_type+"_labels"][0][i])


        self.tensors = np.array(self.tensors)

        self.labels = np.array(self.labels)


    def __len__(self):

        try:

            if len(self.tensors) != len(self.labels):

                raise Exception("Lengths of the tensor and labels list are not the same")

        except Exception as e:

            print(e.args[0])

        return len(self.tensors)


    def __getitem__(self,idx):

        sample = (self.tensors[idx],self.labels[idx])

       # print(self.labels)

        sample = (torch.from_numpy(self.tensors[idx]),torch.from_numpy(np.array(self.labels[idx])).long())

        return sample

    #tuple containing the image patch and its corresponding label


ibeautiful
浏览 261回答 2
2回答

凤凰求蛊

这是一个字典理解;在这种特殊情况下,它dict从现有的 dict创建一个新的temp,但仅适用于键k不以下划线开头的项目。该检查由if ...零件执行。它相当于new = {}for k, v in temp.items():    if key[0] != '_':        new[k] = valuetemp = new或者,略有不同:new = {}for key, value in temp.items():    if not key.startswith('_'):        new[key] = valuetemp = new您可以看到它作为单行看起来更好一些,因为它避免了临时 dict (new; 在幕后,它仍然创建了一个无名的临时 dict )。
随时随地看视频慕课网APP

相关分类

Python
我要回答