如何以支持 autograd 的方式围绕其中心旋转 PyTorch 图像张量?

我想围绕其中心随机旋转图像张量(B、C、H、W)(我认为是二维旋转?)。我想避免使用 NumPy 和 Kornia,这样我基本上只需要从 torch 模块导入。我也没有使用torchvision.transforms,因为我需要它与 autograd 兼容。本质上,我正在尝试为 DeepDream 等可视化技术创建一个 autograd 兼容版本torchvision.transforms.RandomRotation()(因此我需要尽可能避免伪影)。

import torch

import math

import random

import torchvision.transforms as transforms

from PIL import Image



# Load image

def preprocess_simple(image_name, image_size):

    Loader = transforms.Compose([transforms.Resize(image_size), transforms.ToTensor()])

    image = Image.open(image_name).convert('RGB')

    return Loader(image).unsqueeze(0)

    

# Save image   

def deprocess_simple(output_tensor, output_name):

    output_tensor.clamp_(0, 1)

    Image2PIL = transforms.ToPILImage()

    image = Image2PIL(output_tensor.squeeze(0))

    image.save(output_name)



# Somehow rotate tensor around it's center

def rotate_tensor(tensor, radians):

    ...

    return rotated_tensor


# Get a random angle within a specified range 

r_degrees = 5

angle_range = list(range(-r_degrees, r_degrees))

n = random.randint(angle_range[0], angle_range[len(angle_range)-1])


# Convert angle from degrees to radians

ang_rad = angle * math.pi / 180



# test_tensor = preprocess_simple('path/to/file', (512,512))

test_tensor = torch.randn(1,3,512,512)



# Rotate input tensor somehow

output_tensor = rotate_tensor(test_tensor, ang_rad)



# Optionally use this to check rotated image

# deprocess_simple(output_tensor, 'rotated_image.jpg')

我想要完成的一些示例输出:

https://img2.mukewang.com/651f794d0001060f01850091.jpg

米琪卡哇伊
浏览 103回答 3
3回答

神不在的星期二

因此,网格生成器和采样器是 Spatial Transformer 的子模块(JADERBERG、Max 等人)。这些子模块不可训练,它们可让您应用可学习的以及不可学习的空间变换。theta在这里,我使用这两个子模块,并使用 PyTorch 的函数torch.nn.functional.affine_grid和(这些函数分别是生成器和采样器的实现)来旋转图像torch.nn.functional.affine_sample:import torchimport torch.nn.functional as Fimport numpy as npimport matplotlib.pyplot as pltdef get_rot_mat(theta):    theta = torch.tensor(theta)    return torch.tensor([[torch.cos(theta), -torch.sin(theta), 0],                         [torch.sin(theta), torch.cos(theta), 0]])def rot_img(x, theta, dtype):    rot_mat = get_rot_mat(theta)[None, ...].type(dtype).repeat(x.shape[0],1,1)    grid = F.affine_grid(rot_mat, x.size()).type(dtype)    x = F.grid_sample(x, grid)    return x#Test:dtype =  torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor#im should be a 4D tensor of shape B x C x H x W with type dtype, range [0,255]:plt.imshow(im.squeeze(0).permute(1,2,0)/255) #To plot it im should be 1 x C x H x Wplt.figure()#Rotation by np.pi/2 with autograd support:rotated_im = rot_img(im, np.pi/2, dtype) # Rotate image by 90 degrees.plt.imshow(rotated_im.squeeze(0).permute(1,2,0)/255)在上面的示例中,假设我们将图像im视为一只穿着裙子跳舞的猫:rotated_im将是一只穿着裙子逆时针旋转 90 度的跳舞猫:如果我们用rot_img等号theta调用,就会得到以下结果np.pi/4: 最好的部分是它可以区分输入并具有 autograd 支持!万岁!

森林海

使用 torchvision 应该很简单:import torchvision.transforms.functional as TFangle = 30x = torch.randn(1,3,512,512)out = TF.rotate(x, angle)例如如果x是:out旋转 30 度为(注:逆时针):

慕姐8265434

pytorch 有一个函数:x = torch.tensor([[0, 1],             [2, 3]]) x = torch.rot90(x, 1, [0, 1])>> tensor([[1, 3],            [0, 2]])以下是文档:https://pytorch.org/docs/stable/ generated/torch.rot90.html
打开App,查看更多内容
随时随地看视频慕课网APP

相关分类

Python