Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for non-square render resolutions (width != height) #96

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/demo_deform.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def main():
model = Model(args.template_mesh).cuda()
transform = sr.LookAt(viewing_angle=15)
lighting = sr.Lighting()
rasterizer = sr.SoftRasterizer(image_size=64, sigma_val=1e-4, aggr_func_rgb='hard')
rasterizer = sr.SoftRasterizer(image_size=(64, 64), sigma_val=1e-4, aggr_func_rgb='hard')

# read training images and camera poses
images = np.load(args.filename_input).astype('float32') / 255.
Expand Down
33 changes: 18 additions & 15 deletions soft_renderer/cuda/soft_rasterize_cuda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ std::vector<at::Tensor> forward_soft_rasterize_cuda(
at::Tensor faces_info,
at::Tensor aggrs_info,
at::Tensor soft_colors,
int image_size,
int image_width,
int image_height,
float near,
float far,
float eps,
Expand All @@ -36,7 +37,8 @@ std::vector<at::Tensor> backward_soft_rasterize_cuda(
at::Tensor grad_faces,
at::Tensor grad_textures,
at::Tensor grad_soft_colors,
int image_size,
int image_width,
int image_height,
float near,
float far,
float eps,
Expand All @@ -55,14 +57,14 @@ std::vector<at::Tensor> backward_soft_rasterize_cuda(
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)


std::vector<at::Tensor> forward_soft_rasterize(
at::Tensor faces,
at::Tensor textures,
at::Tensor faces_info,
at::Tensor aggrs_info,
at::Tensor soft_colors,
int image_size,
int image_width,
int image_height,
float near,
float far,
float eps,
Expand All @@ -81,12 +83,12 @@ std::vector<at::Tensor> forward_soft_rasterize(
CHECK_INPUT(aggrs_info);
CHECK_INPUT(soft_colors);

return forward_soft_rasterize_cuda(faces, textures,
faces_info, aggrs_info,
soft_colors,
image_size, near, far, eps,
return forward_soft_rasterize_cuda(faces, textures,
faces_info, aggrs_info,
soft_colors,
image_width, image_height, near, far, eps,
sigma_val, func_id_dist, dist_eps,
gamma_val, func_id_rgb, func_id_alpha,
gamma_val, func_id_rgb, func_id_alpha,
texture_sample_type, double_side);
}

Expand All @@ -100,7 +102,8 @@ std::vector<at::Tensor> backward_soft_rasterize(
at::Tensor grad_faces,
at::Tensor grad_textures,
at::Tensor grad_soft_colors,
int image_size,
int image_width,
int image_height,
float near,
float far,
float eps,
Expand All @@ -122,12 +125,12 @@ std::vector<at::Tensor> backward_soft_rasterize(
CHECK_INPUT(grad_textures);
CHECK_INPUT(grad_soft_colors);

return backward_soft_rasterize_cuda(faces, textures, soft_colors,
faces_info, aggrs_info,
grad_faces, grad_textures, grad_soft_colors,
image_size, near, far, eps,
return backward_soft_rasterize_cuda(faces, textures, soft_colors,
faces_info, aggrs_info,
grad_faces, grad_textures, grad_soft_colors,
image_width, image_height, near, far, eps,
sigma_val, func_id_dist, dist_eps,
gamma_val, func_id_rgb, func_id_alpha,
gamma_val, func_id_rgb, func_id_alpha,
texture_sample_type, double_side);
}

Expand Down
98 changes: 54 additions & 44 deletions soft_renderer/cuda/soft_rasterize_cuda_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,8 @@ __global__ void forward_soft_rasterize_inv_cuda_kernel(
scalar_t* faces_info,
int batch_size,
int num_faces,
int image_size) {
int image_width,
int image_height) {
/* batch number, face number, image size, face[v012][RGB] */
const int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i >= batch_size * num_faces) {
Expand Down Expand Up @@ -290,7 +291,8 @@ __global__ void forward_soft_rasterize_cuda_kernel(
scalar_t* soft_colors,
int batch_size,
int num_faces,
int image_size,
int image_width,
int image_height,
int texture_size,
int texture_res,
float near,
Expand All @@ -309,17 +311,18 @@ __global__ void forward_soft_rasterize_cuda_kernel(
////////////////////////

const int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i >= batch_size * image_size * image_size) {
if (i >= batch_size * image_width * image_height) {
return;
}
const int is = image_size;
const int iw = image_width;
const int ih = image_height;
const int nf = num_faces;
const int bn = i / (is * is);
const int pn = i % (is * is);
const int yi = is - 1 - (pn / is);
const int xi = pn % is;
const scalar_t yp = (2. * yi + 1. - is) / is;
const scalar_t xp = (2. * xi + 1. - is) / is;
const int bn = i / (iw * ih);
const int pn = i % (iw * ih);
const int yi = ih - 1 - (pn / iw);
const int xi = pn % iw;
const scalar_t yp = (2. * yi + 1. - ih) / ih;
const scalar_t xp = (2. * xi + 1. - iw) / iw;

const scalar_t *face = &faces[bn * nf * 9] - 9;
const scalar_t *texture = &textures[bn * nf * texture_size * 3] - texture_size * 3;
Expand All @@ -334,10 +337,10 @@ __global__ void forward_soft_rasterize_cuda_kernel(
scalar_t softmax_max = eps;
for (int k = 0; k < 3; k++) {
if (func_id_rgb == 0) { // hard assign, set to background
soft_color[k] = soft_colors[(bn * 4 + k) * (is * is) + pn];
soft_color[k] = soft_colors[(bn * 4 + k) * (iw * ih) + pn];
} else
if (func_id_rgb == 1) {
soft_color[k] = soft_colors[(bn * 4 + k) * (is * is) + pn] * softmax_sum; // initialize background color
soft_color[k] = soft_colors[(bn * 4 + k) * (iw * ih) + pn] * softmax_sum; // initialize background color
}
}
scalar_t depth_min = 10000000;
Expand Down Expand Up @@ -432,29 +435,29 @@ __global__ void forward_soft_rasterize_cuda_kernel(

// finalize aggregation
if (func_id_alpha == 0) {
soft_colors[(bn * 4 + 3) * (is * is) + pn] = soft_color[3];
soft_colors[(bn * 4 + 3) * (iw * ih) + pn] = soft_color[3];
} else
if (func_id_alpha == 1) {
soft_colors[(bn * 4 + 3) * (is * is) + pn] = soft_color[3] / nf;
soft_colors[(bn * 4 + 3) * (iw * ih) + pn] = soft_color[3] / nf;
} else
if (func_id_alpha == 2) {
soft_colors[(bn * 4 + 3) * (is * is) + pn] = 1. - soft_color[3];
soft_colors[(bn * 4 + 3) * (iw * ih) + pn] = 1. - soft_color[3];
}

if (func_id_rgb == 0) {
if (face_index_min != -1)
for (int k = 0; k < 3; k++) {
soft_colors[(bn * 4 + k) * (is * is) + pn] = soft_color[k];
soft_colors[(bn * 4 + k) * (iw * ih) + pn] = soft_color[k];
}
aggrs_info[(bn * 2 + 0) * (is * is) + pn] = depth_min;
aggrs_info[(bn * 2 + 1) * (is * is) + pn] = face_index_min;
aggrs_info[(bn * 2 + 0) * (iw * ih) + pn] = depth_min;
aggrs_info[(bn * 2 + 1) * (iw * ih) + pn] = face_index_min;
} else
if (func_id_rgb == 1) {
for (int k = 0; k < 3; k++) {
soft_colors[(bn * 4 + k) * (is * is) + pn] = soft_color[k] / softmax_sum;
soft_colors[(bn * 4 + k) * (iw * ih) + pn] = soft_color[k] / softmax_sum;
}
aggrs_info[(bn * 2 + 0) * (is * is) + pn] = softmax_sum;
aggrs_info[(bn * 2 + 1) * (is * is) + pn] = softmax_max;
aggrs_info[(bn * 2 + 0) * (iw * ih) + pn] = softmax_sum;
aggrs_info[(bn * 2 + 1) * (iw * ih) + pn] = softmax_max;
}
}

Expand All @@ -471,7 +474,8 @@ __global__ void backward_soft_rasterize_cuda_kernel(
scalar_t* grad_soft_colors,
int batch_size,
int num_faces,
int image_size,
int image_width,
int image_height,
int texture_size,
int texture_res,
float near,
Expand All @@ -490,26 +494,27 @@ __global__ void backward_soft_rasterize_cuda_kernel(
////////////////////////

const int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i >= batch_size * image_size * image_size) {
if (i >= batch_size * image_width * image_height) {
return;
}
const int is = image_size;
const int iw = image_width;
const int ih = image_height;
const int nf = num_faces;
const int bn = i / (is * is);
const int pn = i % (is * is);
const int yi = is - 1 - (pn / is);
const int xi = pn % is;
const scalar_t yp = (2. * yi + 1 - is) / is;
const scalar_t xp = (2. * xi + 1 - is) / is;
const int bn = i / (iw * ih);
const int pn = i % (iw * ih);
const int yi = ih - 1 - (pn / iw);
const int xi = pn % iw;
const scalar_t yp = (2. * yi + 1 - ih) / ih;
const scalar_t xp = (2. * xi + 1 - iw) / iw;

const scalar_t* face = &faces[bn * nf * 9] - 9;
const scalar_t* texture = &textures[bn * nf * texture_size * 3] - texture_size * 3;
const scalar_t* face_info = &faces_info[bn * nf * 27] - 27;

const scalar_t threshold = dist_eps * sigma_val;

const scalar_t softmax_sum = aggrs_info[(bn * 2 + 0) * (is * is) + pn];
const scalar_t softmax_max = aggrs_info[(bn * 2 + 1) * (is * is) + pn];
const scalar_t softmax_sum = aggrs_info[(bn * 2 + 0) * (iw * ih) + pn];
const scalar_t softmax_max = aggrs_info[(bn * 2 + 1) * (iw * ih) + pn];

for (int fn = 0; fn < nf; fn++) {
face += 9;
Expand Down Expand Up @@ -556,15 +561,15 @@ __global__ void backward_soft_rasterize_cuda_kernel(
/////////////////////////////////////////////////////

// aggregate for alpha channel
scalar_t C_grad_xy_alpha = grad_soft_colors[(bn * 4 + 3) * (is * is) + pn];
scalar_t C_grad_xy_alpha = grad_soft_colors[(bn * 4 + 3) * (iw * ih) + pn];
if (func_id_alpha == 0) { // hard assign
// a hard-assigned alpha channel has no gradient
} else
if (func_id_alpha == 1) { // Sum
C_grad_xy_alpha /= nf;
} else
if (func_id_alpha == 2) { // Logical-Or
C_grad_xy_alpha *= (1 - soft_colors[(bn * 4 + 3) * (is * is) + pn]) / max(1 - soft_fragment, 1e-6);
C_grad_xy_alpha *= (1 - soft_colors[(bn * 4 + 3) * (iw * ih) + pn]) / max(1 - soft_fragment, 1e-6);
}
C_grad_xy += C_grad_xy_alpha;

Expand All @@ -579,7 +584,7 @@ __global__ void backward_soft_rasterize_cuda_kernel(
if (fn == softmax_max) {
for (int k = 0; k < 3; k++) {
for (int j = 0; j < texture_size; j++) {
atomicAdd(&grad_texture[3 * j + k], backward_sample_texture(grad_soft_colors[(bn * 4 + k) * (is * is) + pn], w, texture_res, j, texture_sample_type));
atomicAdd(&grad_texture[3 * j + k], backward_sample_texture(grad_soft_colors[(bn * 4 + k) * (iw * ih) + pn], w, texture_res, j, texture_sample_type));
}
}
}
Expand All @@ -591,15 +596,15 @@ __global__ void backward_soft_rasterize_cuda_kernel(
const scalar_t zp_softmax = soft_fragment * exp((zp_norm - softmax_max) / gamma_val) / softmax_sum;

for (int k = 0; k < 3; k++) {
const scalar_t grad_soft_color_k = grad_soft_colors[(bn * 4 + k) * (is * is) + pn];
const scalar_t grad_soft_color_k = grad_soft_colors[(bn * 4 + k) * (iw * ih) + pn];

for (int j = 0; j < texture_size; j++) {
const scalar_t grad_t = backward_sample_texture(grad_soft_color_k, w, texture_res, j, texture_sample_type);
atomicAdd(&grad_texture[3 * j + k], zp_softmax * grad_t);
}

const scalar_t color_k = forward_sample_texture(texture, w, texture_res, k, texture_sample_type);
C_grad_xyz_rgb += grad_soft_color_k * (color_k - soft_colors[(bn * 4 + k) * (is * is) + pn]);
C_grad_xyz_rgb += grad_soft_color_k * (color_k - soft_colors[(bn * 4 + k) * (iw * ih) + pn]);
}
C_grad_xyz_rgb *= zp_softmax;
C_grad_xy += C_grad_xyz_rgb / soft_fragment;
Expand Down Expand Up @@ -648,7 +653,8 @@ std::vector<at::Tensor> forward_soft_rasterize_cuda(
at::Tensor faces_info,
at::Tensor aggrs_info,
at::Tensor soft_colors,
int image_size,
int image_width,
int image_height,
float near,
float far,
float eps,
Expand All @@ -674,14 +680,15 @@ std::vector<at::Tensor> forward_soft_rasterize_cuda(
faces_info.data<scalar_t>(),
batch_size,
num_faces,
image_size);
image_width,
image_height);
}));

cudaError_t err = cudaGetLastError();
if (err != cudaSuccess)
printf("Error in forward_transform_inv_triangle: %s\n", cudaGetErrorString(err));

const dim3 blocks_2 ((batch_size * image_size * image_size - 1) / threads +1);
const dim3 blocks_2 ((batch_size * image_width * image_height - 1) / threads +1);

AT_DISPATCH_FLOATING_TYPES(faces.type(), "forward_eff_soft_rasterize_cuda", ([&] {
forward_soft_rasterize_cuda_kernel<scalar_t><<<blocks_2, threads>>>(
Expand All @@ -692,7 +699,8 @@ std::vector<at::Tensor> forward_soft_rasterize_cuda(
soft_colors.data<scalar_t>(),
batch_size,
num_faces,
image_size,
image_width,
image_height,
texture_size,
texture_res,
near,
Expand Down Expand Up @@ -725,7 +733,8 @@ std::vector<at::Tensor> backward_soft_rasterize_cuda(
at::Tensor grad_faces,
at::Tensor grad_textures,
at::Tensor grad_soft_colors,
int image_size,
int image_width,
int image_height,
float near,
float far,
float eps,
Expand All @@ -743,7 +752,7 @@ std::vector<at::Tensor> backward_soft_rasterize_cuda(
const auto texture_size = textures.size(2);
const auto texture_res = int(sqrt(texture_size));
const int threads = 512;
const dim3 blocks ((batch_size * image_size * image_size - 1) / threads + 1);
const dim3 blocks ((batch_size * image_width * image_height - 1) / threads + 1);

AT_DISPATCH_FLOATING_TYPES(faces.type(), "backward_soft_rasterize_cuda", ([&] {
backward_soft_rasterize_cuda_kernel<scalar_t><<<blocks, threads>>>(
Expand All @@ -757,7 +766,8 @@ std::vector<at::Tensor> backward_soft_rasterize_cuda(
grad_soft_colors.data<scalar_t>(),
batch_size,
num_faces,
image_size,
image_width,
image_height,
texture_size,
texture_res,
near,
Expand Down
Loading