Skip to content

Commit

Permalink
adm_dwt2_cy.pyx: fix cython warnings
Browse files Browse the repository at this point in the history
  • Loading branch information
nilfm99 committed Jan 29, 2024
1 parent b6548d4 commit 0fea102
Showing 1 changed file with 20 additions and 19 deletions.
39 changes: 20 additions & 19 deletions python/vmaf/core/adm_dwt2_cy.pyx
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
# import sys
# The following line is needed to suppress a warning due to numpy's backward compatibility issues
# See more: https://github.com/numpy/numpy/issues/21865
# distutils: define_macros=NPY_NO_DEPRECATED_API=NPY_1_7_API_VERSION

import numpy as np
cimport numpy as np
Expand All @@ -7,20 +9,19 @@ from libc.stdlib cimport calloc, free

from vmaf.core.adm_dwt2_tools import ALIGN_CEIL, MAX_ALIGN

ctypedef double np_float

cdef struct adm_dwt_band_t:
np_float *band_a
np_float *band_v
np_float *band_h
np_float *band_d
cdef extern from "feature/adm_tools.h":
ctypedef struct adm_dwt_band_t_d:
double *band_a
double *band_v
double *band_h
double *band_d

cdef extern from "../../../libvmaf/src/feature/adm_tools.c":
void dwt2_src_indices_filt_s(int **src_ind_y, int **src_ind_x, int w, int h)
void adm_dwt2_d(const np_float *src, const adm_dwt_band_t *dst, int **ind_y, int **ind_x, int w, int h, int src_stride, int dst_stride)
void adm_dwt2_d(const double *src, const adm_dwt_band_t_d *dst, int **ind_y, int **ind_x, int w, int h, int src_stride, int dst_stride)

cdef extern from "../../../libvmaf/src/feature/adm.c":
char *init_dwt_band_d(adm_dwt_band_t *band, char *data_top, size_t buf_sz_one)
char *init_dwt_band_d(adm_dwt_band_t_d *band, char *data_top, size_t buf_sz_one)

cdef extern from "../../../libvmaf/src/mem.c":
void *aligned_malloc(size_t size, size_t alignment)
Expand All @@ -32,13 +33,13 @@ cdef extern from "../../../libvmaf/src/feature/offset.c":
def adm_dwt2_cy(np.ndarray[np.float64_t, ndim=2, mode='c'] a):

cdef np.ndarray[np.float64_t, ndim=2, mode='c'] a_buf = np.ascontiguousarray(a, dtype=np.float64)
cdef np_float *aa = <np_float*> a_buf.data # aa: curr_ref_scale
cdef double *aa = <double*> a_buf.data # aa: curr_ref_scale

cdef int h = len(a)
cdef int w = len(a[0])
cdef int h_new, w_new

cdef np_float *data_buf = NULL
cdef double *data_buf = NULL
cdef char *data_top

cdef char *ind_buf_y = NULL
Expand All @@ -48,18 +49,18 @@ def adm_dwt2_cy(np.ndarray[np.float64_t, ndim=2, mode='c'] a):
cdef int *ind_y[4]
cdef int *ind_x[4]

cdef adm_dwt_band_t aa_dwt2 # aa_dwt2: ref_dwt2
cdef adm_dwt_band_t_d aa_dwt2 # aa_dwt2: ref_dwt2

cdef int curr_ref_stride = w * sizeof(np_float)
cdef int buf_stride = ALIGN_CEIL(((w + 1) // 2) * sizeof(np_float))
cdef int curr_ref_stride = w * sizeof(double)
cdef int buf_stride = ALIGN_CEIL(((w + 1) // 2) * sizeof(double))
cdef size_t buf_sz_one = <size_t> buf_stride * ((h + 1) // 2)

cdef int ind_size_y = ALIGN_CEIL(((h + 1) // 2) * sizeof(int))
cdef int ind_size_x = ALIGN_CEIL(((w + 1) // 2) * sizeof(int))

# == # must use calloc to initialize mem to 0: adm_dwt2_s doesn't touch every cell for small w and h ==
# data_buf = <np_float *> aligned_malloc(buf_sz_one * 4, MAX_ALIGN)
data_buf = <np_float *> calloc(buf_sz_one * 4, 1)
# data_buf = <double *> aligned_malloc(buf_sz_one * 4, MAX_ALIGN)
data_buf = <double *> calloc(buf_sz_one * 4, 1)
if not data_buf:
free(data_buf)
aligned_free(buf_y_orig)
Expand Down Expand Up @@ -101,11 +102,11 @@ def adm_dwt2_cy(np.ndarray[np.float64_t, ndim=2, mode='c'] a):
w_new = (w + 1) // 2
h_new = (h + 1) // 2

w_new_strided = ALIGN_CEIL(w_new * sizeof(np_float)) // sizeof(np_float)
w_new_strided = ALIGN_CEIL(w_new * sizeof(double)) // sizeof(double)

# # # ====== debug ======
# print("h={}, w={}, aa[0]={}, aa[1]={}, aa[2]={}".format(h, w, aa[0], aa[1], aa[2]))
# print("sizeof(np_float)={}".format(sizeof(np_float)))
# print("sizeof(double)={}".format(sizeof(double)))
# print("curr_ref_stride={}, buf_stride={}, buf_sz_one={}".format(curr_ref_stride, buf_stride, buf_sz_one))
# print("ind_size_y={}, ind_size_x={}".format(ind_size_y, ind_size_x))
# print("h_new={}, w_new={}".format(h_new, w_new))
Expand Down

0 comments on commit 0fea102

Please sign in to comment.