diff --git a/libkmod/libkmod-file-zstd.c b/libkmod/libkmod-file-zstd.c index 220d77ab..5bccf761 100644 --- a/libkmod/libkmod-file-zstd.c +++ b/libkmod/libkmod-file-zstd.c @@ -3,6 +3,9 @@ * Copyright © 2024 Intel Corporation */ +/* TODO: replace with build system define once supported */ +#define DLSYM_LOCALLY_ENABLED 0 + #include #include #include @@ -19,6 +22,24 @@ #include "libkmod-internal.h" #include "libkmod-internal-file.h" +#define DL_SYMBOL_TABLE(M) \ + M(ZSTD_decompress) \ + M(ZSTD_getErrorName) \ + M(ZSTD_getFrameContentSize) \ + M(ZSTD_isError) + +DL_SYMBOL_TABLE(DECLARE_SYM) + +static int dlopen_zstd(void) +{ + static void *dl = NULL; + + if (!DLSYM_LOCALLY_ENABLED) + return 0; + + return dlsym_many(&dl, "libzstd.so.1", DL_SYMBOL_TABLE(DLSYM_ARG) NULL); +} + int kmod_file_load_zstd(struct kmod_file *file) { void *src_buf = MAP_FAILED, *dst_buf = NULL; @@ -27,6 +48,13 @@ int kmod_file_load_zstd(struct kmod_file *file) struct stat st; int ret; + ret = dlopen_zstd(); + if (ret < 0) { + ERR(file->ctx, "zstd: can't load and resolve symbols (%s)", + strerror(-ret)); + return -EINVAL; + } + if (fstat(file->fd, &st) < 0) { ret = -errno; ERR(file->ctx, "zstd: %m\n"); @@ -45,7 +73,7 @@ int kmod_file_load_zstd(struct kmod_file *file) goto out; } - frame_size = ZSTD_getFrameContentSize(src_buf, src_size); + frame_size = sym_ZSTD_getFrameContentSize(src_buf, src_size); if (frame_size == 0 || frame_size == ZSTD_CONTENTSIZE_UNKNOWN || frame_size == ZSTD_CONTENTSIZE_ERROR) { ret = -EINVAL; @@ -65,9 +93,9 @@ int kmod_file_load_zstd(struct kmod_file *file) goto out; } - dst_size = ZSTD_decompress(dst_buf, dst_size, src_buf, src_size); - if (ZSTD_isError(dst_size)) { - ERR(file->ctx, "zstd: %s\n", ZSTD_getErrorName(dst_size)); + dst_size = sym_ZSTD_decompress(dst_buf, dst_size, src_buf, src_size); + if (sym_ZSTD_isError(dst_size)) { + ERR(file->ctx, "zstd: %s\n", sym_ZSTD_getErrorName(dst_size)); ret = -EINVAL; goto out; }