From 250df4cf4cae68ce470e59eda062542b302d2b75 Mon Sep 17 00:00:00 2001 From: Yohan Chatelain Date: Fri, 26 Jul 2024 03:06:17 -0400 Subject: [PATCH] UD is working, fix SR --- .../numerics/random_sum/plot.py | 52 + .../numerics/random_sum/test.c | 23 + .../numerics/random_sum/test.sh | 85 + src/libvfcinstrumentonline/Makefile.am | 2 + .../libVFCInstrumentOnline.cpp | 2271 ++--------------- src/libvfcinstrumentonline/rand.cpp | 334 +++ src/libvfcinstrumentonline/shishua.h | 20 +- src/libvfcinstrumentonline/xoroshiro128+.cpp | 104 + src/libvfcinstrumentonline/xoroshiro128+.hpp | 17 + tests/test_online_instrumentation/test.c | 30 +- verificarlo.in.in | 14 +- 11 files changed, 793 insertions(+), 2159 deletions(-) create mode 100644 experiments/online-instrumentation/numerics/random_sum/plot.py create mode 100644 experiments/online-instrumentation/numerics/random_sum/test.c create mode 100755 experiments/online-instrumentation/numerics/random_sum/test.sh create mode 100644 src/libvfcinstrumentonline/rand.cpp create mode 100644 src/libvfcinstrumentonline/xoroshiro128+.cpp create mode 100644 src/libvfcinstrumentonline/xoroshiro128+.hpp diff --git a/experiments/online-instrumentation/numerics/random_sum/plot.py b/experiments/online-instrumentation/numerics/random_sum/plot.py new file mode 100644 index 00000000..729e53f5 --- /dev/null +++ b/experiments/online-instrumentation/numerics/random_sum/plot.py @@ -0,0 +1,52 @@ +import numpy as np +import plotly.express as px +import pandas as pd +import glob +import argparse +import joblib +from functools import reduce + + +def parse_file(file): + with open(file, "r") as f: + for line in f: + if "msec" in line: + return float(line.split()[0]) + + +def run(): + files = glob.glob("*.perf") + values = {} + for file in files: + # opt_mode_type.perf + opt, mode, type = file.split(".")[0].split("_") + x = parse_file(file) + values["opt"] = values.get("opt", []) + [opt] + values["mode"] = values.get("mode", []) + [mode] + values["type"] = values.get("type", []) + [type] + values["slowdown"] = values.get("slowdown", []) + [x] + return pd.DataFrame(values) + + +def main(): + df = run() + df[df["opt"] == "O0"]["slowdown"] /= 2 # O0 + df[df["opt"] == "O0"]["slowdown"] /= 0.7 # O3 + print(df) + fig = px.bar( + df, x="opt", y="slowdown", color="mode", barmode="group", facet_col="type" + ) + # fig.update_yaxes(title_text="Slowdown") + fig.update_layout( + font=dict( + size=22, + ) + ) + fig.show() + + +if __name__ == "__main__": + main() + + +# /usr/local/bin/verificarlo-c++ -DAT_PER_OPERATOR_HEADERS -DCAFFE2_BUILD_MAIN_LIB -DCPUINFO_SUPPORTED_PLATFORM=1 -DFMT_HEADER_ONLY=1 -DFXDIV_USE_INLINE_ASSEMBLY=0 -DHAVE_MALLOC_USABLE_SIZE=1 -DHAVE_MMAP=1 -DHAVE_SHM_OPEN=1 -DHAVE_SHM_UNLINK=1 -DMINIZ_DISABLE_ZIP_READER_CRC32_CHECKS -DNNP_CONVOLUTION_ONLY=0 -DNNP_INFERENCE_ONLY=0 -DONNXIFI_ENABLE_EXT=1 -DONNX_ML=1 -DONNX_NAMESPACE=onnx_torch -DUSE_C10D_GLOO -DUSE_DISTRIBUTED -DUSE_EXTERNAL_MZCRC -DUSE_RPC -DUSE_TENSORPIPE -D_FILE_OFFSET_BITS=64 -Dtorch_cpu_EXPORTS -I/var/lib/jenkins/workspace/third_party/pocketfft -I/var/lib/jenkins/workspace/build/aten/src -I/var/lib/jenkins/workspace/aten/src -I/var/lib/jenkins/workspace/build -I/var/lib/jenkins/workspace -I/var/lib/jenkins/workspace/cmake/../third_party/benchmark/include-I/var/lib/jenkins/workspace/third_party/onnx -I/var/lib/jenkins/workspace/build/third_party/onnx -I/var/lib/jenkins/workspace/third_party/foxi -I/var/lib/jenkins/workspace/build/third_party/foxi -I/var/lib/jenkins/workspace/torch/csrc/api -I/var/lib/jenkins/workspace/torch/csrc/api/include -I/var/lib/jenkins/workspace/caffe2/aten/src/TH -I/var/lib/jenkins/workspace/build/caffe2/aten/src/TH -I/var/lib/jenkins/workspace/build/caffe2/aten/src -I/var/lib/jenkins/workspace/build/caffe2/../aten/src -I/var/lib/jenkins/workspace/torch/csrc -I/var/lib/jenkins/workspace/third_party/miniz-2.1.0 -I/var/lib/jenkins/workspace/third_party/kineto/libkineto/include -I/var/lib/jenkins/workspace/third_party/kineto/libkineto/src -I/var/lib/jenkins/workspace/aten/src/ATen/.. -I/var/lib/jenkins/workspace/third_party/FXdiv/include -I/var/lib/jenkins/workspace/c10/.. -I/var/lib/jenkins/workspace/third_party/pthreadpool/include -I/var/lib/jenkins/workspace/third_party/cpuinfo/include -I/var/lib/jenkins/workspace/third_party/QNNPACK/include -I/var/lib/jenkins/workspace/aten/src/ATen/native/quantized/cpu/qnnpack/include -I/var/lib/jenkins/workspace/aten/src/ATen/native/quantized/cpu/qnnpack/src -I/var/lib/jenkins/workspace/third_party/cpuinfo/deps/clog/include -I/var/lib/jenkins/workspace/third_party/NNPACK/include -I/var/lib/jenkins/workspace/third_party/fbgemm/include -I/var/lib/jenkins/workspace/third_party/fbgemm -I/var/lib/jenkins/workspace/third_party/fbgemm/third_party/asmjit/src -I/var/lib/jenkins/workspace/third_party/ittapi/src/ittnotify -I/var/lib/jenkins/workspace/third_party/FP16/include -I/var/lib/jenkins/workspace/third_party/tensorpipe -I/var/lib/jenkins/workspace/build/third_party/tensorpipe -I/var/lib/jenkins/workspace/third_party/tensorpipe/third_party/libnop/include -I/var/lib/jenkins/workspace/third_party/fmt/include -I/var/lib/jenkins/workspace/third_party/flatbuffers/include -isystem /var/lib/jenkins/workspace/build/third_party/gloo -isystem /var/lib/jenkins/workspace/cmake/../third_party/gloo -isystem /var/lib/jenkins/workspace/cmake/../third_party/tensorpipe/third_party/libuv/include -isystem /var/lib/jenkins/workspace/cmake/../third_party/googletest/googlemock/include -isystem /var/lib/jenkins/workspace/cmake/../third_party/googletest/googletest/include -isystem /var/lib/jenkins/workspace/third_party/protobuf/src -isystem /var/lib/jenkins/workspace/third_party/gemmlowp -isystem /var/lib/jenkins/workspace/third_party/neon2sse -isystem /var/lib/jenkins/workspace/third_party/XNNPACK/include -isystem /var/lib/jenkins/workspace/third_party/ittapi/include -isystem /var/lib/jenkins/workspace/cmake/../third_party/eigen -isystem /var/lib/jenkins/workspace/build/include -march=native --online-instrumentation=up-down --inst-fma -D_GLIBCXX_USE_CXX11_ABI=1 -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -DNDEBUG -DUSE_KINETO -DLIBKINETO_NOCUPTI -DLIBKINETO_NOROCTRACER -DUSE_FBGEMM -DUSE_QNNPACK -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -O2 -fPIC -Wall -Wextra -Werror=return-type -Werror=non-virtual-dtor -Werror=braced-scalar-init -Wnarrowing -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-unused-parameter -Wno-unused-function -Wno-unused-result -Wno-strict-overflow -Wno-strict-aliasing -Wvla-extension -Wnewline-eof -Winconsistent-missing-override -Winconsistent-missing-destructor-override -Wno-range-loop-analysis -Wno-pass-failed -Wno-error=pedantic -Wno-error=old-style-cast -Wno-error=inconsistent-missing-override -Wno-error=inconsistent-missing-destructor-override -Wconstant-conversion -Wno-invalid-partial-specialization -Wno-unused-private-field -Wno-missing-braces -Wunused-lambda-capture -Qunused-arguments -fcolor-diagnostics -faligned-new -fno-math-errno -fno-trapping-math -Werror=format -DHAVE_AVX512_CPU_DEFINITION -DHAVE_AVX2_CPU_DEFINITION -O3 -DNDEBUG -DNDEBUG -std=gnu++17 -fPIC -DTORCH_USE_LIBUV -DCAFFE2_USE_GLOO -DTH_HAVE_THREAD -Wall -Wextra -Wno-unused-parameter -Wno-unused-function -Wno-unused-result -Wno-missing-field-initializers -Wno-unknown-pragmas -Wno-type-limits -Wno-array-bounds -Wno-strict-overflow -Wno-strict-aliasing -Wno-missing-braces -Wno-range-loop-analysis -fvisibility=hidden -O2 -Wmissing-prototypes -Werror=missing-prototypes -pthread -DASMJIT_STATIC -fopenmp=libiomp5 -MD -MT caffe2/CMakeFiles/torch_cpu.dir/__/aten/src/ATen/native/mkl/SpectralOps.cpp.o -MF CMakeFiles/torch_cpu.dir/__/aten/src/ATen/native/mkl/SpectralOps.cpp.o.d -o CMakeFiles/torch_cpu.dir/__/aten/src/ATen/native/mkl/SpectralOps.cpp.o -c /var/lib/jenkins/workspace/aten/src/ATen/native/mkl/SpectralOps.cpp diff --git a/experiments/online-instrumentation/numerics/random_sum/test.c b/experiments/online-instrumentation/numerics/random_sum/test.c new file mode 100644 index 00000000..7fd89375 --- /dev/null +++ b/experiments/online-instrumentation/numerics/random_sum/test.c @@ -0,0 +1,23 @@ +// harmonic series +#include +#include +#include + +#ifndef REAL +#define REAL double +#endif + +int main(int argc, char *argv[]) { + + int n = atoi(argv[1]); + int seed = atoi(argv[2]); + + srand(seed); + + REAL sum = 1; + for (int i = 1; i <= n; i++) { + sum += 0.1; + } + fprintf(stderr, "%.13a\n", sum); + return 0; +} \ No newline at end of file diff --git a/experiments/online-instrumentation/numerics/random_sum/test.sh b/experiments/online-instrumentation/numerics/random_sum/test.sh new file mode 100755 index 00000000..4fe93468 --- /dev/null +++ b/experiments/online-instrumentation/numerics/random_sum/test.sh @@ -0,0 +1,85 @@ +#!/bin/bash + +for type in float double; do + + echo "Compiling test.c with $type precision" + + verificarlo-c test.c -DREAL=$type --online-instrumentation=up-down -O0 -o test_O0_ud_$type + verificarlo-c test.c -DREAL=$type --online-instrumentation=up-down -O3 -o test_O3_ud_$type + + # rm -rf ud_results_${type}_O0.*.txt + # rm -rf ud_results_${type}_O3.*.txt + + # echo "Running up-down instrumentation" + # for i in {0..10}; do + # ./test_O0 $N $RANDOM >ud_results_${type}_O0.$i.txt + # ./test_O3 $N $RANDOM >ud_results_${type}_O3.$i.txt + # done + + verificarlo-c test.c -DREAL=$type --online-instrumentation=sr -O0 -o test_O0_sr_$type -lm + verificarlo-c test.c -DREAL=$type --online-instrumentation=sr -O3 -o test_O3_sr_$type -lm + + # rm -rf sr_results_${type}_O0.*.txt + # rm -rf sr_results_${type}_O3.*.txt + + # echo "Running sr instrumentation" + # for i in {0..10}; do + # ./test_O0 $N $RANDOM >sr_results_${type}_O0.$i.txt + # ./test_O3 $N $RANDOM >sr_results_${type}_O3.$i.txt + # done + + verificarlo-c test.c -DREAL=$type -o test_O0_ieee_$type + verificarlo-c test.c -DREAL=$type -o test_O3_ieee_$type + + cp test_O0_ieee_$type test_O0_mca_$type + cp test_O3_ieee_$type test_O3_mca_$type + + # rm -rf ieee_results_${type}_O0.*.txt + # rm -rf ieee_results_${type}_O3.*.txt + + # export VFC_BACKENDS_LOGGER=False + # export VFC_BACKENDS_SILENT_LOAD=True + # export VFC_BACKENDS="libinterflop_ieee.so" + + # echo "Running ieee instrumentation" + # for i in {0..10}; do + # ./test_O0 $N $RANDOM >ieee_results_${type}_O0.$i.txt + # ./test_O3 $N $RANDOM >ieee_results_${type}_O3.$i.txt + # done + + # rm -rf mca_results_${type}_O0.*.txt + # rm -rf mca_results_${type}_O3.*.txt + + # export VFC_BACKENDS="libinterflop_mca.so -m rr" + + # echo "Running mca instrumentation" + # for i in {0..10}; do + # ./test_O0 $N $RANDOM >mca_results_${type}_O0.$i.txt + # ./test_O3 $N $RANDOM >mca_results_${type}_O3.$i.txt + # done + +done + +function run() { + N=1000000 + + local mode=$1 + local type=$2 + local backend=$3 + + rm -rf ${mode}_results_${type}_O0.*.txt + rm -rf ${mode}_results_${type}_O3.*.txt + + export VFC_BACKENDS_LOGGER=False + export VFC_BACKENDS_SILENT_LOAD=True + export VFC_BACKENDS=$backend + + echo "Running ${mode} instrumentation" + perf stat -o 00_${mode}_${type}.perf -r 10 -- ./test_O0_${mode}_${type} $N 23 2>/dev/null + perf stat -o 03_${mode}_${type}.perf -r 10 -- ./test_O3_${mode}_${type} $N 23 2>/dev/null + +} + +export -f run + +parallel --progress -j 1 "run {1} {2} {3}" ::: ud sr ieee mca ::: float double ::: "libinterflop_ieee.so" "libinterflop_mca.so -m rr" diff --git a/src/libvfcinstrumentonline/Makefile.am b/src/libvfcinstrumentonline/Makefile.am index 9c162d90..0c7e47af 100644 --- a/src/libvfcinstrumentonline/Makefile.am +++ b/src/libvfcinstrumentonline/Makefile.am @@ -9,3 +9,5 @@ endif libvfcinstrumentonline_la_CXXFLAGS = @LLVM_CPPFLAGS@ -I@INTERFLOP_INCLUDEDIR@ -Wfatal-errors libvfcinstrumentonline_la_LDFLAGS = @LLVM_LDFLAGS@ libvfcinstrumentonline_la_SOURCES = libVFCInstrumentOnline.cpp + +include_HEADERS = rand.cpp xoroshiro128+.cpp shishua.h \ No newline at end of file diff --git a/src/libvfcinstrumentonline/libVFCInstrumentOnline.cpp b/src/libvfcinstrumentonline/libVFCInstrumentOnline.cpp index a0dd79bf..9e1521ca 100644 --- a/src/libvfcinstrumentonline/libVFCInstrumentOnline.cpp +++ b/src/libvfcinstrumentonline/libVFCInstrumentOnline.cpp @@ -24,6 +24,7 @@ #include "llvm/IR/Instructions.h" #include "llvm/IR/Intrinsics.h" #include "llvm/IR/Module.h" +#include "llvm/Linker/Linker.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/FileSystem.h" @@ -31,6 +32,7 @@ #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include +#include #include #include #include @@ -45,6 +47,7 @@ #else #include #endif +#include "llvm/IR/Mangler.h" #include #include #include @@ -100,10 +103,10 @@ static cl::opt VfclibInstExcludeFile( cl::desc("Do not instrument modules / functions in file ExcludeNameFile "), cl::value_desc("ExcludeNameFile"), cl::init("")); -// static cl::opt -// VfclibInstVfcwrapper("vfclibinst-vfcwrapper-file", -// cl::desc("Name of the vfcwrapper IR file "), -// cl::value_desc("VfcwrapperIRFile"), cl::init("")); +static cl::opt VfclibInstSRIRFile( + "vfclibinst-sr-file", + cl::desc("Name of the IR file that contains the sr operators"), + cl::value_desc("SRIRFile"), cl::init("")); static cl::opt VfclibInstVerbose("vfclibinst-verbose", cl::desc("Activate verbose mode"), @@ -129,6 +132,10 @@ static cl::opt cl::desc("Instrument floating point fma"), cl::value_desc("InstrumentFMA"), cl::init(false)); +static cl::opt VfclibInstDebug("vfclibinst-debug", + cl::desc("Activate debug mode"), + cl::value_desc("Debug"), cl::init(false)); + /* pointer that hold the vfcwrapper Module */ // static Module *vfcwrapperM = nullptr; @@ -149,8 +156,9 @@ std::map validTypesMap = { /* valid vector sizes to instrument */ const std::set validVectorSizes = {2, 4, 8, 16, 32, 64}; -/* SHISHUA buffer size */ -const unsigned shishua_buffer_size = 256; +std::map demangledNamesToMangled; +std::map demangledShortNamesToMangled; +std::set functionsToExclude; struct VfclibInst : public ModulePass { static char ID; @@ -298,9 +306,71 @@ struct VfclibInst : public ModulePass { return std::regex(moduleRegex); } + /* Load vfcwrapper.ll Module */ + std::unique_ptr loadVfcwrapperIR(Module &M) { + SMDiagnostic err; + auto newM = parseIRFile(VfclibInstSRIRFile, err, M.getContext()); + if (newM.get() == nullptr) { + err.print(VfclibInstSRIRFile.c_str(), errs()); + report_fatal_error("libVFCInstrumentOnline fatal error"); + } + return newM; + } + + void getDemangledNamesLibSR(Module *M) { + for (auto &F : M->functions()) { + if (F.isDeclaration()) { + continue; + } + const std::string &mangled_name = F.getName().str(); + functionsToExclude.insert(mangled_name); + + std::string demangled_name = get_demangled_name(mangled_name); + demangledNamesToMangled[demangled_name] = mangled_name; + + size_t parenPos = demangled_name.find('('); + std::string demangled_name_short = + (parenPos != std::string::npos) ? demangled_name.substr(0, parenPos) + : demangled_name; + + demangledShortNamesToMangled[demangled_name_short] = mangled_name; + } + } + + void printDemangledNamesLibsSR() { + errs() << "Demangled names\n"; + for (auto &p : demangledNamesToMangled) { + errs() << p.first << " : " << p.second << "\n"; + } + } + + void printDemangledNamesLibsSRShort() { + errs() << "Short demangled names\n"; + for (auto &p : demangledShortNamesToMangled) { + errs() << p.first << " : " << p.second << "\n"; + } + } + bool runOnModule(Module &M) { bool modified = false; + auto ir = loadVfcwrapperIR(M); + // if ir is null, an error message has already been printed + if (ir.get() == nullptr) { + report_fatal_error("libVFCInstrumentOnline fatal error while reading IR"); + } + getDemangledNamesLibSR(ir.get()); + + if (VfclibInstDebug) { + printDemangledNamesLibsSR(); + printDemangledNamesLibsSRShort(); + } + + if (Linker::linkModules(M, std::move(loadVfcwrapperIR(M)))) { + report_fatal_error( + "libVFCInstrumentOnline fatal error when linking modules"); + } + // Parse both included and excluded function set std::regex includeFunctionRgx = parseFunctionSetFile(M, VfclibInstIncludeFile); @@ -319,6 +389,11 @@ struct VfclibInst : public ModulePass { const std::string &name = F.getName().str(); + // Function in the sr library to exclude + if (functionsToExclude.find(name) != functionsToExclude.end()) { + continue; + } + // Included-list if (std::regex_match(name, includeFunctionRgx)) { functions.push_back(&F); @@ -344,71 +419,10 @@ struct VfclibInst : public ModulePass { modified |= runOnFunction(M, *F); } - if (modified) { - Function *init = getOrCreateInitRNGFunction(M); - insertRNGInitFunction(init); - } - // runOnModule must return true if the pass modifies the IR return modified; } - void insertRNGInitFunction(Function *init) { - Module *M = init->getParent(); - Function *F = init; - - // Create the types for the global variable - IntegerType *Int32Ty = Type::getInt32Ty(M->getContext()); - PointerType *VoidPtrTy = Type::getInt8PtrTy(M->getContext()); - FunctionType *VoidFuncTy = - FunctionType::get(Type::getVoidTy(M->getContext()), false); - PointerType *VoidFuncPtrTy = PointerType::get(VoidFuncTy, 0); - - // Create the new elements to be added - ConstantInt *NewPriority = ConstantInt::get(Int32Ty, 65534); - Function *InitFunc = getOrInsertFunction(M, "_sr_init_rng", VoidFuncTy); - - Constant *NullPtr = ConstantPointerNull::get(VoidPtrTy); - - // Create the struct type - StructType *ElemTy = StructType::get(Int32Ty, VoidFuncPtrTy, VoidPtrTy); - - // Retrieve the existing llvm.global_ctors variable - GlobalVariable *GlobalCtors = M->getGlobalVariable("llvm.global_ctors"); - - std::vector Elements; - - if (GlobalCtors) { - // Get the existing initializer - if (ConstantArray *InitList = - dyn_cast(GlobalCtors->getInitializer())) { - for (unsigned i = 0; i < InitList->getNumOperands(); ++i) { - Elements.push_back(InitList->getOperand(i)); - } - } - } - - // Add the new element to the list - Constant *NewElem = - ConstantStruct::get(ElemTy, NewPriority, InitFunc, NullPtr); - Elements.push_back(NewElem); - - // Create the new initializer - ArrayType *ArrayTy = ArrayType::get(ElemTy, Elements.size()); - Constant *NewInit = ConstantArray::get(ArrayTy, Elements); - - GlobalVariable *GV = - new GlobalVariable(*M, ArrayTy, false, GlobalValue::AppendingLinkage, - NewInit, "llvm.global_ctors"); - - if (GlobalCtors != nullptr) { - GlobalCtors->replaceAllUsesWith(GV); - auto name = GlobalCtors->getName(); - GlobalCtors->eraseFromParent(); - GV->setName(name); - } - } - /* Check if Instruction I is a valid instruction to replace; scalar case */ bool isValidScalarInstruction(Type *opType) { bool isValidType = @@ -421,8 +435,12 @@ struct VfclibInst : public ModulePass { /* Check if Instruction I is a valid instruction to replace; vector case */ bool isValidVectorInstruction(Type *opType) { - VectorType *vecType = static_cast(opType); + if (opType == nullptr) { + errs() << "Unsupported operand type\n"; + } + auto vecType = static_cast(opType); auto baseType = vecType->getScalarType(); + #if LLVM_VERSION_MAJOR >= 13 if (isa(vecType)) report_fatal_error("Scalable vector type are not supported"); @@ -462,841 +480,6 @@ struct VfclibInst : public ModulePass { return modified; } - Value *insertRNGdouble01Call(IRBuilder<> &Builder, Instruction *I) { - Module *M = I->getModule(); - std::string randName = "get_rand_double01"; - Type *voidTy = Type::getVoidTy(I->getContext()); - Type *doubleTy = Type::getDoubleTy(I->getContext()); - FunctionType *funcType = FunctionType::get(doubleTy, {}, false); -#if LLVM_VERSION_MAJOR < 9 - Constant *RNG = M->getOrInsertFunction(randName, funcType); -#else - FunctionCallee RNG = M->getOrInsertFunction(randName, funcType); -#endif - Instruction *call = Builder.CreateCall(RNG); - return call; - } - - Type *getFPAsIntType(Type *type) { - if (type->isFloatTy()) { - return Type::getInt32Ty(type->getContext()); - } else if (type->isDoubleTy()) { - return Type::getInt64Ty(type->getContext()); - } else { - const std::string function_name = __func__; - errs() << "[" + function_name + "] " + "Unsupported type: " << *type - << "\n"; - report_fatal_error("libVFCInstrument fatal error"); - } - } - - Type *getFPAsUIntType(Type *type) { - if (type->isFloatTy()) { - return Type::getInt32Ty(type->getContext()); - } else if (type->isDoubleTy()) { - return Type::getInt64Ty(type->getContext()); - } else { - const std::string function_name = __func__; - errs() << "[" + function_name + "] " + "Unsupported type: " << *type - << "\n"; - report_fatal_error("libVFCInstrument fatal error"); - } - } - - Constant *getUIntValue(Type *type, uint64_t value) { - Type *intTy = getFPAsIntType(type); - if (type->isFloatTy() or type->isDoubleTy()) { - return ConstantInt::get(intTy, value); - } else { - const std::string function_name = __func__; - errs() << "[" + function_name + "] " + "Unsupported type: " << *type - << "\n"; - report_fatal_error("libVFCInstrument fatal error"); - } - } - - Constant *getIntValue(Type *type, uint64_t value) { - Type *intTy = getFPAsIntType(type); - if (type->isFloatTy() or type->isDoubleTy()) { - return ConstantInt::get(intTy, value, true); - } else { - const std::string function_name = __func__; - errs() << "[" + function_name + "] " + "Unsupported type: " << *type - << "\n"; - report_fatal_error("libVFCInstrument fatal error"); - } - } - - std::string sr_hook_name(Instruction &I) { - switch (I.getOpcode()) { - case Instruction::FAdd: - return "add2"; - case Instruction::FSub: - // In LLVM IR the FSub instruction is used to represent FNeg - return "sub2"; - case Instruction::FMul: - return "mul2"; - case Instruction::FDiv: - return "div2"; - default: - errs() << "Unsupported opcode: " << I.getOpcodeName() << "\n"; - report_fatal_error("libVFCInstrument fatal error"); - } - } - - std::string getTypeName(Type *type) { - - if (type->isVectorTy()) { - VectorType *VT = dyn_cast(type); - type = VT->getElementType(); - } - - switch (type->getTypeID()) { - case Type::VoidTyID: - return "void"; - case Type::HalfTyID: - return "half"; - case Type::FloatTyID: - return "float"; - case Type::DoubleTyID: - return "double"; - case Type::X86_FP80TyID: - return "x86_fp80"; - case Type::FP128TyID: - return "fp128"; - case Type::PPC_FP128TyID: - return "ppc_fp128"; - case Type::IntegerTyID: - return "integer"; - case Type::FunctionTyID: - return "function"; - case Type::StructTyID: - return "struct"; - case Type::ArrayTyID: - return "array"; - case Type::PointerTyID: - return "pointer"; - default: - return "unknown"; - } - } - - void insertRandUint64XoroshiroCall(IRBuilder<> &Builder, Instruction *I, - Value **rand) { - Module *M = I->getModule(); - Type *uint64Ty = Type::getInt64Ty(M->getContext()); - - // Get the global variables, set true to search for local variables - GlobalVariable *rng_state0 = M->getGlobalVariable("rng_state.0", true); - GlobalVariable *rng_state1 = M->getGlobalVariable("rng_state.1", true); - - if (rng_state0 == nullptr) { - rng_state0 = new GlobalVariable( - *M, - /* type */ uint64Ty, - /* isConstant */ false, - /* linkage */ GlobalValue::InternalLinkage, - /* initializer */ ConstantInt::get(uint64Ty, 0), - /* name */ "rng_state.0", - /* insertbefore */ nullptr, - /* threadmode */ GlobalValue::ThreadLocalMode::GeneralDynamicTLSModel, - /* addresspace */ 0, - /* isExternallyInitialized */ false); - } - if (rng_state1 == nullptr) { - rng_state1 = new GlobalVariable( - *M, - /* type */ uint64Ty, - /* isConstant */ false, - /* linkage */ GlobalValue::InternalLinkage, - /* initializer */ ConstantInt::get(uint64Ty, 0), - /* name */ "rng_state.1", - /* insertbefore */ nullptr, - /* threadmode */ GlobalValue::ThreadLocalMode::GeneralDynamicTLSModel, - /* addresspace */ 0, - /* isExternallyInitialized */ false); - } - - // Get rand uint64 - Value *rng_state0_load = Builder.CreateLoad(uint64Ty, rng_state0); - Value *rng_state1_load = Builder.CreateLoad(uint64Ty, rng_state1); - Value *add = Builder.CreateAdd(rng_state1_load, rng_state0_load); - Value *shl = Builder.CreateShl(add, 17); - Value *lshr = Builder.CreateLShr(add, 47); - Value *add2 = Builder.CreateAdd(lshr, rng_state0_load); - Value *or2 = Builder.CreateOr(add2, shl); - Value *xor2 = Builder.CreateXor(rng_state1_load, rng_state0_load); - Value *shl2 = Builder.CreateShl(rng_state0_load, 49); - Value *lshr2 = Builder.CreateLShr(rng_state0_load, 15); - Value *xor3 = Builder.CreateXor(xor2, lshr2); - Value *shl3 = Builder.CreateShl(xor2, 21); - Value *xor4 = Builder.CreateXor(xor3, shl3); - *rand = Builder.CreateOr(xor4, shl2, "rand_uint64"); - Builder.CreateStore(*rand, rng_state0); -#if LLVM_VERSION_MAJOR < 9 - std::vector args = {xor2, xor2, Builder.getInt64(28)}; - Value *fshl = Builder.CreateIntrinsic(Intrinsic::fshl, args); -#else - Value *fshl = Builder.CreateIntrinsic(Intrinsic::fshl, {xor2->getType()}, - {xor2, xor2, Builder.getInt64(28)}); -#endif - Builder.CreateStore(fshl, rng_state1); - } - - template - Value *getVectorConstant(Module *M, std::initializer_list values) { - // define Type *ty depending on the type of T - Type *ty = nullptr; - if (std::is_same::value or std::is_same::value) { - ty = Type::getInt32Ty(M->getContext()); - } else if (std::is_same::value or - std::is_same::value) { - ty = Type::getInt64Ty(M->getContext()); - } else { - errs() << "Unsupported type: " << typeid(T).name() << "\n"; - report_fatal_error("libVFCInstrument fatal error"); - } - - std::vector constants; - for (auto value : values) { - constants.push_back(ConstantInt::get(ty, value)); - } - return ConstantVector::get(constants); - } - - template - Value *getVectorConstant(Module *M, std::vector values) { - // define Type *ty depending on the type of T - Type *ty = nullptr; - if (std::is_same::value or std::is_same::value) { - ty = Type::getInt32Ty(M->getContext()); - } else if (std::is_same::value or - std::is_same::value) { - ty = Type::getInt64Ty(M->getContext()); - } else { - errs() << "Unsupported type: " << typeid(T).name() << "\n"; - report_fatal_error("libVFCInstrument fatal error"); - } - - std::vector constants; - for (auto value : values) { - constants.push_back(ConstantInt::get(ty, value)); - } - return ConstantVector::get(constants); - } - - // clang-format off -// ; Function Attrs: nofree nosync nounwind uwtable -// define dso_local void @prng_init(%struct.prng_state* nocapture noundef %0, i64* nocapture noundef readonly %1) local_unnamed_addr #0 { -// %3 = bitcast %struct.prng_state* %0 to i8* -// tail call void @llvm.memset.p0i8.i64(i8* noundef nonnull align 32 dereferenceable(288) %3, i8 0, i64 288, i1 false) -// %4 = bitcast i64* %1 to <2 x i64>* -// %5 = load <2 x i64>, <2 x i64>* %4, align 8, !tbaa !5 -// %6 = xor <2 x i64> %5, -// %7 = shufflevector <2 x i64> %6, <2 x i64> poison, <4 x i32> -// %8 = shufflevector <4 x i64> %7, <4 x i64> , <4 x i32> -// %9 = getelementptr inbounds %struct.prng_state, %struct.prng_state* %0, i64 0, i32 0, i64 0 -// store <4 x i64> %8, <4 x i64>* %9, align 32, !tbaa !9 -// %10 = getelementptr inbounds i64, i64* %1, i64 2 -// %11 = bitcast i64* %10 to <2 x i64>* -// %12 = load <2 x i64>, <2 x i64>* %11, align 8, !tbaa !5 -// %13 = xor <2 x i64> %12, -// %14 = shufflevector <2 x i64> %13, <2 x i64> poison, <4 x i32> -// %15 = shufflevector <4 x i64> %14, <4 x i64> , <4 x i32> -// %16 = getelementptr inbounds %struct.prng_state, %struct.prng_state* %0, i64 0, i32 0, i64 1 -// store <4 x i64> %15, <4 x i64>* %16, align 32, !tbaa !9 -// %17 = bitcast i64* %10 to <2 x i64>* -// %18 = load <2 x i64>, <2 x i64>* %17, align 8, !tbaa !5 -// %19 = xor <2 x i64> %18, -// %20 = shufflevector <2 x i64> %19, <2 x i64> poison, <4 x i32> -// %21 = shufflevector <4 x i64> %20, <4 x i64> , <4 x i32> -// %22 = getelementptr inbounds %struct.prng_state, %struct.prng_state* %0, i64 0, i32 0, i64 2 -// store <4 x i64> %21, <4 x i64>* %22, align 32, !tbaa !9 -// %23 = bitcast i64* %1 to <2 x i64>* -// %24 = load <2 x i64>, <2 x i64>* %23, align 8, !tbaa !5 -// %25 = xor <2 x i64> %24, -// %26 = shufflevector <2 x i64> %25, <2 x i64> poison, <4 x i32> -// %27 = shufflevector <4 x i64> %26, <4 x i64> , <4 x i32> -// br label %35 - -// 28: ; preds = %35 -// %29 = getelementptr inbounds %struct.prng_state, %struct.prng_state* %0, i64 0, i32 0, i64 3 -// %30 = getelementptr inbounds %struct.prng_state, %struct.prng_state* %0, i64 0, i32 2 -// %31 = getelementptr inbounds %struct.prng_state, %struct.prng_state* %0, i64 0, i32 1, i64 3 -// %32 = getelementptr inbounds %struct.prng_state, %struct.prng_state* %0, i64 0, i32 1, i64 2 -// %33 = getelementptr inbounds %struct.prng_state, %struct.prng_state* %0, i64 0, i32 1, i64 1 -// %34 = getelementptr inbounds %struct.prng_state, %struct.prng_state* %0, i64 0, i32 1, i64 0 -// store <4 x i64> %228, <4 x i64>* %9, align 32, !tbaa !9 -// store <4 x i64> %227, <4 x i64>* %16, align 32, !tbaa !9 -// store <4 x i64> %226, <4 x i64>* %22, align 32, !tbaa !9 -// store <4 x i64> %225, <4 x i64>* %29, align 32, !tbaa !9 -// store <4 x i64> %229, <4 x i64>* %30, align 32, !tbaa !9 -// store <4 x i64> %225, <4 x i64>* %34, align 32, !tbaa !9 -// store <4 x i64> %226, <4 x i64>* %33, align 32, !tbaa !9 -// store <4 x i64> %227, <4 x i64>* %32, align 32, !tbaa !9 -// store <4 x i64> %228, <4 x i64>* %31, align 32, !tbaa !9 -// ret void - -// 35: ; preds = %2, %35 -// %36 = phi i64 [ 0, %2 ], [ %230, %35 ] -// %37 = phi <4 x i64> [ %8, %2 ], [ %228, %35 ] -// %38 = phi <4 x i64> [ %15, %2 ], [ %227, %35 ] -// %39 = phi <4 x i64> [ %21, %2 ], [ %226, %35 ] -// %40 = phi <4 x i64> [ %27, %2 ], [ %225, %35 ] -// %41 = phi <4 x i64> [ zeroinitializer, %2 ], [ %229, %35 ] -// %42 = add <4 x i64> %41, %38 -// %43 = add <4 x i64> %41, %40 -// %44 = or <4 x i64> %41, -// %45 = lshr <4 x i64> %37, -// %46 = lshr <4 x i64> %42, -// %47 = lshr <4 x i64> %39, -// %48 = lshr <4 x i64> %43, -// %49 = bitcast <4 x i64> %37 to <8 x i32> -// %50 = shufflevector <8 x i32> %49, <8 x i32> poison, <8 x i32> -// %51 = bitcast <8 x i32> %50 to <4 x i64> -// %52 = bitcast <4 x i64> %42 to <8 x i32> -// %53 = shufflevector <8 x i32> %52, <8 x i32> poison, <8 x i32> -// %54 = bitcast <8 x i32> %53 to <4 x i64> -// %55 = bitcast <4 x i64> %39 to <8 x i32> -// %56 = shufflevector <8 x i32> %55, <8 x i32> poison, <8 x i32> -// %57 = bitcast <8 x i32> %56 to <4 x i64> -// %58 = bitcast <4 x i64> %43 to <8 x i32> -// %59 = shufflevector <8 x i32> %58, <8 x i32> poison, <8 x i32> -// %60 = bitcast <8 x i32> %59 to <4 x i64> -// %61 = add <4 x i64> %45, %51 -// %62 = add <4 x i64> %47, %57 -// %63 = add <4 x i64> %46, %44 -// %64 = add <4 x i64> %63, %54 -// %65 = add <4 x i64> %48, %44 -// %66 = add <4 x i64> %65, %60 -// %67 = add <4 x i64> %41, -// %68 = lshr <4 x i64> %61, -// %69 = lshr <4 x i64> %64, -// %70 = lshr <4 x i64> %62, -// %71 = lshr <4 x i64> %66, -// %72 = bitcast <4 x i64> %61 to <8 x i32> -// %73 = shufflevector <8 x i32> %72, <8 x i32> poison, <8 x i32> -// %74 = bitcast <8 x i32> %73 to <4 x i64> -// %75 = bitcast <4 x i64> %64 to <8 x i32> -// %76 = shufflevector <8 x i32> %75, <8 x i32> poison, <8 x i32> -// %77 = bitcast <8 x i32> %76 to <4 x i64> -// %78 = bitcast <4 x i64> %62 to <8 x i32> -// %79 = shufflevector <8 x i32> %78, <8 x i32> poison, <8 x i32> -// %80 = bitcast <8 x i32> %79 to <4 x i64> -// %81 = bitcast <4 x i64> %66 to <8 x i32> -// %82 = shufflevector <8 x i32> %81, <8 x i32> poison, <8 x i32> -// %83 = bitcast <8 x i32> %82 to <4 x i64> -// %84 = add <4 x i64> %68, %74 -// %85 = add <4 x i64> %70, %80 -// %86 = add <4 x i64> %69, %67 -// %87 = add <4 x i64> %86, %77 -// %88 = add <4 x i64> %71, %67 -// %89 = add <4 x i64> %88, %83 -// %90 = add <4 x i64> %41, -// %91 = lshr <4 x i64> %84, -// %92 = lshr <4 x i64> %87, -// %93 = lshr <4 x i64> %85, -// %94 = lshr <4 x i64> %89, -// %95 = bitcast <4 x i64> %84 to <8 x i32> -// %96 = shufflevector <8 x i32> %95, <8 x i32> poison, <8 x i32> -// %97 = bitcast <8 x i32> %96 to <4 x i64> -// %98 = bitcast <4 x i64> %87 to <8 x i32> -// %99 = shufflevector <8 x i32> %98, <8 x i32> poison, <8 x i32> -// %100 = bitcast <8 x i32> %99 to <4 x i64> -// %101 = bitcast <4 x i64> %85 to <8 x i32> -// %102 = shufflevector <8 x i32> %101, <8 x i32> poison, <8 x i32> -// %103 = bitcast <8 x i32> %102 to <4 x i64> -// %104 = bitcast <4 x i64> %89 to <8 x i32> -// %105 = shufflevector <8 x i32> %104, <8 x i32> poison, <8 x i32> -// %106 = bitcast <8 x i32> %105 to <4 x i64> -// %107 = add <4 x i64> %91, %97 -// %108 = add <4 x i64> %93, %103 -// %109 = add <4 x i64> %92, %90 -// %110 = add <4 x i64> %109, %100 -// %111 = add <4 x i64> %94, %90 -// %112 = add <4 x i64> %111, %106 -// %113 = add <4 x i64> %41, -// %114 = lshr <4 x i64> %107, -// %115 = lshr <4 x i64> %110, -// %116 = lshr <4 x i64> %108, -// %117 = lshr <4 x i64> %112, -// %118 = bitcast <4 x i64> %107 to <8 x i32> -// %119 = shufflevector <8 x i32> %118, <8 x i32> poison, <8 x i32> -// %120 = bitcast <8 x i32> %119 to <4 x i64> -// %121 = bitcast <4 x i64> %110 to <8 x i32> -// %122 = shufflevector <8 x i32> %121, <8 x i32> poison, <8 x i32> -// %123 = bitcast <8 x i32> %122 to <4 x i64> -// %124 = bitcast <4 x i64> %108 to <8 x i32> -// %125 = shufflevector <8 x i32> %124, <8 x i32> poison, <8 x i32> -// %126 = bitcast <8 x i32> %125 to <4 x i64> -// %127 = bitcast <4 x i64> %112 to <8 x i32> -// %128 = shufflevector <8 x i32> %127, <8 x i32> poison, <8 x i32> -// %129 = bitcast <8 x i32> %128 to <4 x i64> -// %130 = add <4 x i64> %114, %120 -// %131 = add <4 x i64> %116, %126 -// %132 = add <4 x i64> %115, %113 -// %133 = add <4 x i64> %132, %123 -// %134 = add <4 x i64> %117, %113 -// %135 = add <4 x i64> %134, %129 -// %136 = add <4 x i64> %41, -// %137 = lshr <4 x i64> %130, -// %138 = lshr <4 x i64> %133, -// %139 = lshr <4 x i64> %131, -// %140 = lshr <4 x i64> %135, -// %141 = bitcast <4 x i64> %130 to <8 x i32> -// %142 = shufflevector <8 x i32> %141, <8 x i32> poison, <8 x i32> -// %143 = bitcast <8 x i32> %142 to <4 x i64> -// %144 = bitcast <4 x i64> %133 to <8 x i32> -// %145 = shufflevector <8 x i32> %144, <8 x i32> poison, <8 x i32> -// %146 = bitcast <8 x i32> %145 to <4 x i64> -// %147 = bitcast <4 x i64> %131 to <8 x i32> -// %148 = shufflevector <8 x i32> %147, <8 x i32> poison, <8 x i32> -// %149 = bitcast <8 x i32> %148 to <4 x i64> -// %150 = bitcast <4 x i64> %135 to <8 x i32> -// %151 = shufflevector <8 x i32> %150, <8 x i32> poison, <8 x i32> -// %152 = bitcast <8 x i32> %151 to <4 x i64> -// %153 = add <4 x i64> %137, %143 -// %154 = add <4 x i64> %139, %149 -// %155 = add <4 x i64> %138, %136 -// %156 = add <4 x i64> %155, %146 -// %157 = add <4 x i64> %140, %136 -// %158 = add <4 x i64> %157, %152 -// %159 = add <4 x i64> %41, -// %160 = lshr <4 x i64> %153, -// %161 = lshr <4 x i64> %156, -// %162 = lshr <4 x i64> %154, -// %163 = lshr <4 x i64> %158, -// %164 = bitcast <4 x i64> %153 to <8 x i32> -// %165 = shufflevector <8 x i32> %164, <8 x i32> poison, <8 x i32> -// %166 = bitcast <8 x i32> %165 to <4 x i64> -// %167 = bitcast <4 x i64> %156 to <8 x i32> -// %168 = shufflevector <8 x i32> %167, <8 x i32> poison, <8 x i32> -// %169 = bitcast <8 x i32> %168 to <4 x i64> -// %170 = bitcast <4 x i64> %154 to <8 x i32> -// %171 = shufflevector <8 x i32> %170, <8 x i32> poison, <8 x i32> -// %172 = bitcast <8 x i32> %171 to <4 x i64> -// %173 = bitcast <4 x i64> %158 to <8 x i32> -// %174 = shufflevector <8 x i32> %173, <8 x i32> poison, <8 x i32> -// %175 = bitcast <8 x i32> %174 to <4 x i64> -// %176 = add <4 x i64> %160, %166 -// %177 = add <4 x i64> %162, %172 -// %178 = add <4 x i64> %161, %159 -// %179 = add <4 x i64> %178, %169 -// %180 = add <4 x i64> %163, %159 -// %181 = add <4 x i64> %180, %175 -// %182 = add <4 x i64> %41, -// %183 = lshr <4 x i64> %176, -// %184 = lshr <4 x i64> %179, -// %185 = lshr <4 x i64> %177, -// %186 = lshr <4 x i64> %181, -// %187 = bitcast <4 x i64> %176 to <8 x i32> -// %188 = shufflevector <8 x i32> %187, <8 x i32> poison, <8 x i32> -// %189 = bitcast <8 x i32> %188 to <4 x i64> -// %190 = bitcast <4 x i64> %179 to <8 x i32> -// %191 = shufflevector <8 x i32> %190, <8 x i32> poison, <8 x i32> -// %192 = bitcast <8 x i32> %191 to <4 x i64> -// %193 = bitcast <4 x i64> %177 to <8 x i32> -// %194 = shufflevector <8 x i32> %193, <8 x i32> poison, <8 x i32> -// %195 = bitcast <8 x i32> %194 to <4 x i64> -// %196 = bitcast <4 x i64> %181 to <8 x i32> -// %197 = shufflevector <8 x i32> %196, <8 x i32> poison, <8 x i32> -// %198 = bitcast <8 x i32> %197 to <4 x i64> -// %199 = add <4 x i64> %183, %189 -// %200 = add <4 x i64> %185, %195 -// %201 = add <4 x i64> %184, %182 -// %202 = add <4 x i64> %201, %192 -// %203 = add <4 x i64> %186, %182 -// %204 = add <4 x i64> %203, %198 -// %205 = lshr <4 x i64> %199, -// %206 = lshr <4 x i64> %202, -// %207 = lshr <4 x i64> %200, -// %208 = lshr <4 x i64> %204, -// %209 = bitcast <4 x i64> %199 to <8 x i32> -// %210 = shufflevector <8 x i32> %209, <8 x i32> poison, <8 x i32> -// %211 = bitcast <8 x i32> %210 to <4 x i64> -// %212 = bitcast <4 x i64> %202 to <8 x i32> -// %213 = shufflevector <8 x i32> %212, <8 x i32> poison, <8 x i32> -// %214 = bitcast <8 x i32> %213 to <4 x i64> -// %215 = bitcast <4 x i64> %200 to <8 x i32> -// %216 = shufflevector <8 x i32> %215, <8 x i32> poison, <8 x i32> -// %217 = bitcast <8 x i32> %216 to <4 x i64> -// %218 = bitcast <4 x i64> %204 to <8 x i32> -// %219 = shufflevector <8 x i32> %218, <8 x i32> poison, <8 x i32> -// %220 = bitcast <8 x i32> %219 to <4 x i64> -// %221 = add <4 x i64> %205, %211 -// %222 = add <4 x i64> %206, %214 -// %223 = add <4 x i64> %207, %217 -// %224 = add <4 x i64> %208, %220 -// %225 = xor <4 x i64> %205, %214 -// %226 = xor <4 x i64> %207, %220 -// %227 = xor <4 x i64> %224, %221 -// %228 = xor <4 x i64> %222, %223 -// %229 = add <4 x i64> %41, -// %230 = add nuw nsw i64 %36, 1 -// %231 = icmp eq i64 %230, 13 -// br i1 %231, label %28, label %35, !llvm.loop !10 -// } - // clang-format on - - Value *insertShishuaInit(IRBuilder<> &Builder, Instruction *I) { - Module *M = I->getModule(); - const std::string function_name = "_shishua_init"; - if (Function *shishua = M->getFunction(function_name)) { - return Builder.CreateCall(shishua); - } - - auto voidTy = Type::getVoidTy(M->getContext()); - std::vector argsTy; - argsTy.push_back(Type::getInt32Ty(M->getContext())); - FunctionType *shishuaInitType = FunctionType::get(voidTy, argsTy, false); - - Function *shishuaInit = - getOrInsertFunction(M, function_name, shishuaInitType); - } - - // clang-format off - // ; Function Attrs: mustprogress nofree noinline norecurse nosync nounwind uwtable willreturn - // define dso_local i64 @get_rand_uint64() local_unnamed_addr #4 { - // entry: - // %1 = load i32, i32* @get_rand_uint64.i, align 4, !tbaa !10 - // %2 = and i32 %1, 3 - // %3 = icmp eq i32 %2, 0 - // br i1 %3, label %4, label %22 - // rng: ; preds = %0 - // %5 = load <4 x i64>, <4 x i64>* @rng_state.0, align 32, !tbaa !9 - // %6 = load <4 x i64>, <4 x i64>* @rng_state.3, align 32, !tbaa !9 - // %7 = load <4 x i64>, <4 x i64>* @rng_state.1, align 32, !tbaa !9 - // %8 = load <4 x i64>, <4 x i64>* @rng_state.2, align 32, !tbaa !9 - // store <4 x i64> %8, <4 x i64>* bitcast ([32 x i8]* @get_rand_uint64.buf to <4 x i64>*), align 32, !tbaa !9 - // %9 = add <4 x i64> %7, %6 - // %10 = add <4 x i64> %6, - // %11 = lshr <4 x i64> %5, - // %12 = lshr <4 x i64> %9, - // %13 = bitcast <4 x i64> %5 to <8 x i32> - // %14 = shufflevector <8 x i32> %13, <8 x i32> poison, <8 x i32> - // %15 = bitcast <8 x i32> %14 to <4 x i64> - // %16 = bitcast <4 x i64> %9 to <8 x i32> - // %17 = shufflevector <8 x i32> %16, <8 x i32> poison, <8 x i32> - // %18 = bitcast <8 x i32> %17 to <4 x i64> - // %19 = add <4 x i64> %11, %15 - // %20 = add <4 x i64> %12, %18 - // %21 = xor <4 x i64> %11, %18 - // store <4 x i64> %19, <4 x i64>* @rng_state.0, align 32, !tbaa !9 - // store <4 x i64> %10, <4 x i64>* @rng_state.3, align 32, !tbaa !9 - // store <4 x i64> %20, <4 x i64>* @rng_state.1, align 32, !tbaa !9 - // store <4 x i64> %21, <4 x i64>* @rng_state.2, align 32, !tbaa !9 - // br label %22 - // end: ; preds = %4, %0 - // %23 = phi i32 [ 0, %4 ], [ %1, %0 ] - // %24 = shl nsw i32 %23, 3 - // %25 = sext i32 %24 to i64 - // %26 = getelementptr inbounds [32 x i8], [32 x i8]* @get_rand_uint64.buf, i64 0, i64 %25 - // %27 = bitcast i8* %26 to i64* - // %28 = load i64, i64* %27, align 8 - // %29 = add nsw i32 %23, 1 - // store i32 %29, i32* @get_rand_uint64.i, align 4, !tbaa !10 - // ret i64 %28 - // } - // clang-format on - void insertRandUint64ShishuaCall(IRBuilder<> &Builder, Instruction *I, - Value **rand) { - - int nbBytesRequested = 1; - if (I->getType()->isVectorTy()) { - auto *vecTy = static_cast(I->getType()); -#if LLVM_VERSION_MAJOR >= 11 - auto size = vecTy->getElementCount().getKnownMinValue(); -#else - auto size = vecTy->getElementCount(); -#endif - nbBytesRequested = (int)ceil(size / 8.0); - } - - Module *M = I->getModule(); - const std::string function_name = - "_shishua_uint64" + std::to_string(nbBytesRequested); - if (Function *shishua = M->getFunction(function_name)) { - std::vector args = {Builder.getInt32(nbBytesRequested)}; - *rand = Builder.CreateCall(shishua, args); - return; - } - - BasicBlock *caller = Builder.GetInsertBlock(); - Function *originalFunction = I->getParent()->getParent(); - Type *int8Ty = Type::getInt8Ty(M->getContext()); - Type *int32Ty = Type::getInt32Ty(M->getContext()); - Type *int64Ty = Type::getInt64Ty(M->getContext()); - Type *floatTy = Type::getFloatTy(M->getContext()); - Type *int32x8Ty = GET_VECTOR_TYPE(int32Ty, 8); - Type *int64x4Ty = GET_VECTOR_TYPE(int64Ty, 4); - - FunctionType *shishuType = FunctionType::get(int64Ty, {int32Ty}, false); - - Function *shishua = getOrInsertFunction(M, function_name, shishuType); - - Argument *nbBytesRequestedArg = &*shishua->arg_begin(); - - errs() << "Inserting shishua call\n"; - errs() << "shishua: " << *shishua << "\n"; - BasicBlock *entryBB = - BasicBlock::Create(Builder.getContext(), "entry", shishua); - BasicBlock *rngBB = - BasicBlock::Create(Builder.getContext(), "rng", shishua); - BasicBlock *endBB = - BasicBlock::Create(Builder.getContext(), "end", shishua); - - // entry block - Builder.SetInsertPoint(entryBB); - GlobalVariable *rand_uint64_i = getGVRandUint64Counter(*M); - auto *randUint64ILoad = Builder.CreateLoad(int32Ty, rand_uint64_i); - randUint64ILoad->setAlignment(Align(4)); - // %2 = icmp sgt i32 %1, 120 - auto cstICmpSgt = 0; - switch (nbBytesRequested) { - case 1: - cstICmpSgt = 255; - break; - case 2: - cstICmpSgt = 254; - break; - case 3: - cstICmpSgt = 252; - break; - case 4: - cstICmpSgt = 248; - break; - default: - errs() << "Unsupported number of bytes requested: " << nbBytesRequested - << "\n"; - report_fatal_error("libVFCInstrument fatal error"); - } - Value *randUint64ICmp = - Builder.CreateICmpSGT(randUint64ILoad, Builder.getInt32(cstICmpSgt)); - Builder.CreateCondBr(randUint64ICmp, rngBB, endBB); - - // rng block - Builder.SetInsertPoint(rngBB); - GlobalVariable *rng_state0 = getGVRNGState(M, "rng_state.0"); - GlobalVariable *rng_state3 = getGVRNGState(M, "rng_state.3"); - GlobalVariable *rng_state1 = getGVRNGState(M, "rng_state.1"); - GlobalVariable *rng_state2 = getGVRNGState(M, "rng_state.2"); - - auto *rngState0Load = Builder.CreateLoad(int64x4Ty, rng_state0); - auto *rngState3Load = Builder.CreateLoad(int64x4Ty, rng_state3); - auto *rngState1Load = Builder.CreateLoad(int64x4Ty, rng_state1); - auto *rngState2Load = Builder.CreateLoad(int64x4Ty, rng_state2); - - rngState0Load->setAlignment(Align(32)); - rngState1Load->setAlignment(Align(32)); - rngState3Load->setAlignment(Align(32)); - rngState2Load->setAlignment(Align(32)); - - GlobalVariable *rand_uint64_buffer = getGVRandUint64Buffer(*M); - Value *bitcast = - Builder.CreateBitCast(rand_uint64_buffer, int64x4Ty->getPointerTo()); - Builder.CreateStore(rngState2Load, bitcast); - Value *add1 = Builder.CreateAdd(rngState1Load, rngState3Load); - Value *cst = getVectorConstant(M, {7, 5, 3, 1}); - Value *add2 = Builder.CreateAdd(rngState3Load, cst); - Value *cst2 = getVectorConstant(M, {1, 1, 1, 1}); - Value *lshr1 = Builder.CreateLShr(rngState0Load, cst2); - Value *cst3 = getVectorConstant(M, {3, 3, 3, 3}); - Value *lshr2 = Builder.CreateLShr(add1, cst3); - Value *bitcast1 = Builder.CreateBitCast(rngState0Load, int32x8Ty); - Value *shufflePermutation1 = - getVectorConstant(M, {5, 6, 7, 0, 1, 2, 3, 4}); - Value *shuffle1 = Builder.CreateShuffleVector( - bitcast1, UndefValue::get(int32x8Ty), shufflePermutation1); - Value *bitcast2 = Builder.CreateBitCast(shuffle1, int64x4Ty); - Value *bitcast3 = Builder.CreateBitCast(add1, int32x8Ty); - Value *shufflePermutation2 = - getVectorConstant(M, {3, 4, 5, 6, 7, 0, 1, 2}); - Value *shuffle2 = Builder.CreateShuffleVector( - bitcast3, UndefValue::get(int32x8Ty), shufflePermutation2); - Value *bitcast4 = Builder.CreateBitCast(shuffle2, int64x4Ty); - Value *add3 = Builder.CreateAdd(lshr1, bitcast2); - Value *add4 = Builder.CreateAdd(lshr2, bitcast4); - Value *xor1 = Builder.CreateXor(lshr1, bitcast4); - Builder.CreateStore(add3, rng_state0); - Builder.CreateStore(add2, rng_state3); - Builder.CreateStore(add4, rng_state1); - Builder.CreateStore(xor1, rng_state2); - Builder.CreateBr(endBB); - - // end block - Builder.SetInsertPoint(endBB); - PHINode *phiNode = Builder.CreatePHI(int32Ty, 2); - phiNode->addIncoming(Builder.getInt32(0), rngBB); - phiNode->addIncoming(randUint64ILoad, entryBB); - Value *shl = Builder.CreateShl(phiNode, 3, "", false, true); - Value *sext = Builder.CreateSExt(shl, int64Ty); - std::vector gepIndices = {Builder.getInt64(0), sext}; - ArrayType *arrayType = ArrayType::get(int8Ty, shishua_buffer_size); - // create a GEP with opaque pointer type - Value *gep = Builder.CreateGEP(arrayType, rand_uint64_buffer, gepIndices); - Value *bitcast5 = Builder.CreateBitCast(gep, int64Ty->getPointerTo()); - Value *load = Builder.CreateLoad(int64Ty, bitcast5); - Value *add5 = Builder.CreateNSWAdd(phiNode, nbBytesRequestedArg); - auto *store = Builder.CreateStore(add5, rand_uint64_i); - store->setAlignment(Align(4)); - Builder.CreateRet(load); - - Builder.SetInsertPoint(caller); - std::vector args = {Builder.getInt32(nbBytesRequested)}; - *rand = Builder.CreateCall(shishua, args); - errs() << "Shishua call inserted\n"; - errs() << "rand: " << **rand << "\n"; - } - - void checkAVX2Support() { - if (not cpuTargetInfo("+avx2")) { - errs() << "AVX2 is not supported on this machine\n"; - report_fatal_error("libVFCInstrument fatal error"); - } - } - - bool cpuTargetInfo(const std::string &feature) { - static bool print = false; - // Detect the host CPU - llvm::StringRef CPU = llvm::sys::getHostCPUName(); - if (not print) { - errs() << "CPU: " << CPU << "\n"; - } - llvm::SubtargetFeatures Features; - llvm::StringMap HostFeatures; - if (llvm::sys::getHostCPUFeatures(HostFeatures)) { - for (auto &Feature : HostFeatures) { - Features.AddFeature(Feature.first(), Feature.second); - } - } - - // Check for AVX2 feature - auto features = Features.getFeatures(); - if (not print) { - errs() << "Features: "; - for (auto &f : features) { - errs() << f << " "; - } - errs() << "\n"; - } - - bool HasFeature = - std::find(features.begin(), features.end(), feature) != features.end(); - - if (not print) { - llvm::outs() << feature << " support: " << (HasFeature ? "Yes" : "No") - << "\n"; - } - print = true; - return HasFeature; - } - - void insertRandUint64Call(IRBuilder<> &Builder, Instruction *I, - Value **rand) { - if (VfclibInstRNG == "xoroshiro") { - insertRandUint64XoroshiroCall(Builder, I, rand); - } else if (VfclibInstRNG == "shishua") { - checkAVX2Support(); - insertRandUint64ShishuaCall(Builder, I, rand); - } else { - errs() << "Unsupported RNG function: " << VfclibInstRNG << "\n"; - report_fatal_error("libVFCInstrument fatal error"); - } - } - - // start get_rand_uint64 - // %3 = load i64, i64* @rng_state.0, align 16, !tbaa !5 - // %4 = load i64, i64* @rng_state.1, align 16, !tbaa !5 - // %5 = add i64 %4, %3 - // %6 = shl i64 %5, 17 - // %7 = lshr i64 %5, 47 - // %8 = add i64 %7, %3 - // %9 = or i64 %8, %6 - // %10 = xor i64 %4, %3 - // %11 = shl i64 %3, 49 - // %12 = lshr i64 %3, 15 - // %13 = xor i64 %10, %12 - // %14 = shl i64 %10, 21 - // %15 = xor i64 %13, %14 - // %16 = or i64 %15, %11 - // store i64 %16, i64* @rng_state.0, align 16, !tbaa !5 - // %17 = call i64 @llvm.fshl.i64(i64 %10, i64 %10, i64 28) #12 - // store i64 %17, i64* @rng_state.1, align 16, !tbaa !5 - // end get_rand_uint64 - // start get_rand_double01 - // %18 = lshr i64 %9, 12 - // %19 = or i64 %18, 4607182418800017408 - // %20 = bitcast i64 %19 to double - // %21 = fadd double %20, -1.000000e+00 - // end get_rand_double01 - Value *insertRandDouble01Call(IRBuilder<> &Builder, Instruction *I) { - Module *M = I->getModule(); - Type *uint64Ty = Type::getInt64Ty(M->getContext()); - - // Get the global variables, set true to search for local variables - GlobalVariable *rng_state0 = M->getGlobalVariable("rng_state.0", true); - GlobalVariable *rng_state1 = M->getGlobalVariable("rng_state.1", true); - - Value *rand = nullptr; - insertRandUint64XoroshiroCall(Builder, I, &rand); - - // Get rand double between 0 and 1 - Value *lshr3 = Builder.CreateLShr(rand, 12); - Value *or4 = Builder.CreateOr(lshr3, Builder.getInt64(4607182418800017408)); - Type *bitCastTy = Type::getDoubleTy(M->getContext()); - Value *bitcast = Builder.CreateBitCast(or4, bitCastTy); - Value *mone = ConstantFP::get(Type::getDoubleTy(M->getContext()), -1.0); - Value *fadd = Builder.CreateFAdd(bitcast, mone); - return fadd; - } - - Value *insertVectorizeRandDouble01Call(IRBuilder<> &Builder, Instruction *I) { - VECTOR_TYPE *VT = dyn_cast(I->getType()); - int size = VT->getNumElements(); - // create undef value with the same type as the vector - Value *rand_double01 = UndefValue::get(VT); - std::vector elementsVec; - for (int i = 0; i < size; i++) { - elementsVec.push_back(insertRandDouble01Call(Builder, I)); - if (VT->getElementType()->isFloatTy()) { - elementsVec[i] = - Builder.CreateFPCast(elementsVec[i], VT->getElementType()); - } - rand_double01 = Builder.CreateInsertElement(rand_double01, elementsVec[i], - Builder.getInt32(i), - "insert" + std::to_string(i)); - } - return rand_double01; - } - - void insertVectorizeRandUint64Call(IRBuilder<> &Builder, Instruction *I, - Value **rand_uint64) { - VECTOR_TYPE *fpVT = dyn_cast(I->getType()); - int size = fpVT->getNumElements(); - VECTOR_TYPE *iVT = GET_VECTOR_TYPE(Type::getInt64Ty(I->getContext()), size); - // create undef value with the same type as the vector - *rand_uint64 = UndefValue::get(iVT); - std::vector elementsVec; - Value *rand = nullptr; - for (int i = 0; i < size; i++) { - insertRandUint64Call(Builder, I, &rand); - elementsVec.push_back(rand); - *rand_uint64 = Builder.CreateInsertElement(*rand_uint64, elementsVec[i], - Builder.getInt32(i), - "insert" + std::to_string(i)); - } - } - - double c99hextodouble(const char *hex) { - char *end; - double d = strtod(hex, &end); - if (*end != '\0') { - fprintf(stderr, "Invalid hex string: %s\n", hex); - exit(1); - } - return d; - } - Function *createFunction(FunctionType *functionType, GlobalValue::LinkageTypes linkage, Module &M, const std::string &name) { @@ -1308,1158 +491,82 @@ struct VfclibInst : public ModulePass { #endif } - Function *getOrCreateSyscallFunction(Module &M) { - Type *int64Ty = Type::getInt64Ty(M.getContext()); - Function *syscallF = M.getFunction("syscall"); - if (syscallF == nullptr) { - FunctionType *syscallTy = FunctionType::get(int64Ty, {int64Ty}, true); - syscallF = - createFunction(syscallTy, GlobalValue::ExternalLinkage, M, "syscall"); - } - return syscallF; - } - - Function *getOrInsertFunction(Module &M, StringRef name, - FunctionType *functionType) { -#if LLVM_VERSION_MAJOR < 9 - Constant *cst = M.getOrInsertFunction(name, functionType); - Function *function = dyn_cast(cst); -#else - FunctionCallee callee = M.getOrInsertFunction(name, functionType); - Function *function = dyn_cast(callee.getCallee()); -#endif - return function; - } - - Function *getOrInsertFunction(Module *M, StringRef name, - FunctionType *functionType) { -#if LLVM_VERSION_MAJOR < 9 - Constant *cst = M->getOrInsertFunction(name, functionType); - Function *function = dyn_cast(cst); -#else - FunctionCallee callee = M->getOrInsertFunction(name, functionType); - Function *function = dyn_cast(callee.getCallee()); -#endif - return function; - } - - Function *getOrInsertGettimeofdayFunction(Module &M, IRBuilder<> &Builder, - StructType *TimevalTy, - StructType *TimezoneTy) { - Type *int64Ty = Type::getInt64Ty(Builder.getContext()); - Type *int8PtrTy = Type::getInt8PtrTy(Builder.getContext()); - - // retrieve the struct timeval and timezone type - - std::vector typeGetTimeOfDay = {PointerType::getUnqual(TimevalTy), - PointerType::getUnqual(TimezoneTy)}; - - Type *int32Ty = Type::getInt32Ty(Builder.getContext()); - FunctionType *GettimeofdayFuncType = - FunctionType::get(int32Ty, typeGetTimeOfDay, false); - - Function *GettimeofdayFunc = - getOrInsertFunction(M, "gettimeofday", GettimeofdayFuncType); - - return GettimeofdayFunc; - } - - void insertGettimeofdayCall(Module &M, IRBuilder<> &Builder, - StructType *TimevalTy, StructType *TimezoneTy, - Value *AllocaTimeval) { - Function *GettimeofdayFunc = - getOrInsertGettimeofdayFunction(M, Builder, TimevalTy, TimezoneTy); - - bool bySyscall = false; - if (GettimeofdayFunc == nullptr) { - GettimeofdayFunc = getOrCreateSyscallFunction(M); - bySyscall = true; - } - - Type *int64Ty = Type::getInt64Ty(Builder.getContext()); - Type *int8PtrTy = Type::getInt8PtrTy(Builder.getContext()); - - Value *BitcastTimeval = - Builder.CreateBitCast(AllocaTimeval, int8PtrTy, "bitcast_timeval"); - - Constant *nullValue = Constant::getNullValue(TimezoneTy->getPointerTo()); - - if (bySyscall) { - Constant *gettimeofdaySyscallId = - ConstantInt::get(int64Ty, SYS_gettimeofday); - std::vector args = {gettimeofdaySyscallId, AllocaTimeval, - nullValue}; - Value *Syscall = Builder.CreateCall(GettimeofdayFunc, args); - } else { - std::vector args = {AllocaTimeval, nullValue}; - Value *Gettimeofday = Builder.CreateCall(GettimeofdayFunc, args); - } - } - - GlobalVariable *getGVAlreadyInitialized(Module &M, IRBuilder<> &Builder) { - Type *boolTy = Type::getInt1Ty(Builder.getContext()); - Constant *falseConst = ConstantInt::get(boolTy, 0); - GlobalVariable *already_initialized = - M.getGlobalVariable("already_initialized"); - if (already_initialized == nullptr) { - already_initialized = new GlobalVariable( - M, - /* type */ boolTy, - /* isConstant */ false, - /* linkage */ GlobalValue::InternalLinkage, - /* initializer */ - falseConst, - /* name */ "already_initialized", - /* insertbefore */ nullptr, - /* threadmode */ - GlobalValue::ThreadLocalMode::GeneralDynamicTLSModel, - /* addresspace */ 0, - /* isExternallyInitialized */ false); - } - return already_initialized; - } - - GlobalVariable *getGVRNGState(Module &M, const std::string &name) { - - Type *int64Ty = Type::getInt64Ty(M.getContext()); - Type *rngType = nullptr; - Constant *zero = nullptr; - if (VfclibInstRNG == "xoroshiro") { - rngType = int64Ty; - zero = ConstantInt::get(int64Ty, 0); - } else if (VfclibInstRNG == "shishua") { - rngType = GET_VECTOR_TYPE(int64Ty, 4); - zero = ConstantInt::get(rngType, 0); - } else { - errs() << "Unsupported RNG function: " << VfclibInstRNG << "\n"; + Function *getFunction(Module *M, StringRef name) { + if (demangledShortNamesToMangled.find(name.str()) == + demangledShortNamesToMangled.end()) { + errs() << "Function " << name << " not found\n"; report_fatal_error("libVFCInstrument fatal error"); } - GlobalVariable *rng_state = M.getGlobalVariable(name, true); - if (rng_state == nullptr) { - rng_state = new GlobalVariable( - M, - /* type */ rngType, - /* isConstant */ false, - /* linkage */ GlobalValue::InternalLinkage, - /* initializer */ zero, - /* name */ name, - /* insertbefore */ nullptr, - /* threadmode */ - GlobalValue::ThreadLocalMode::GeneralDynamicTLSModel, - /* addresspace */ 0, - /* isExternallyInitialized */ false); - } - return rng_state; - } - - GlobalVariable *getGVRNGState(Module *M, const std::string &name) { - return getGVRNGState(*M, name); - } - - GlobalVariable *getGVRandUint64Counter(Module &M) { - Type *int32Ty = Type::getInt32Ty(M.getContext()); - const std::string name = "shishua_buffer_index"; - GlobalVariable *rand_uint64_i = M.getGlobalVariable(name, true); - if (rand_uint64_i == nullptr) { - rand_uint64_i = new GlobalVariable( - M, - /* type */ int32Ty, - /* isConstant */ false, - /* linkage */ GlobalValue::InternalLinkage, - /* initializer */ ConstantInt::get(int32Ty, 0), - /* name */ - name, - /* insertbefore */ nullptr, - /* threadmode */ GlobalValue::ThreadLocalMode::GeneralDynamicTLSModel, - /* addresspace */ 0, - /* isExternallyInitialized */ false); - } - return rand_uint64_i; - } - - // @get_rand_uint64.buf = internal unnamed_addr global [32 x i8] - // zeroinitializer, align 32 - GlobalVariable *getGVRandUint64Buffer(Module &M) { - const std::string name = "shishua_buffer"; - Type *uint8Ty = Type::getInt8Ty(M.getContext()); - ArrayType *arrayType = ArrayType::get(uint8Ty, shishua_buffer_size); - GlobalVariable *rand_uint64_buf = M.getGlobalVariable(name, true); - if (rand_uint64_buf == nullptr) { - rand_uint64_buf = new GlobalVariable( - M, - /* type */ arrayType, - /* isConstant */ false, - /* linkage */ GlobalValue::InternalLinkage, - /* initializer */ ConstantAggregateZero::get(arrayType), - /* name */ name, - /* insertbefore */ nullptr, - /* threadmode */ GlobalValue::ThreadLocalMode::GeneralDynamicTLSModel, - /* addresspace */ 0, - /* isExternallyInitialized */ false); - auto align = Align(32); - rand_uint64_buf->setAlignment(align); - } - return rand_uint64_buf; - } - - Value *insertNextSeed(Module &M, IRBuilder<> &Builder, Value *I) { - Type *int64Ty = Type::getInt64Ty(M.getContext()); - Constant *nextSeedCst1 = ConstantInt::get(int64Ty, -7046029254386353131); - Value *add = Builder.CreateAdd(I, nextSeedCst1); - Value *lshr1 = Builder.CreateLShr(add, 30); - Value *xor3 = Builder.CreateXor(lshr1, add); - Constant *nextSeedCst2 = ConstantInt::get(int64Ty, -4658895280553007687); - Value *mul1 = Builder.CreateMul(xor3, nextSeedCst2); - Value *lshr2 = Builder.CreateLShr(mul1, 27); - Value *xor4 = Builder.CreateXor(lshr2, mul1); - Constant *nextSeedCst3 = ConstantInt::get(int64Ty, -7723592293110705685); - Value *mul2 = Builder.CreateMul(xor4, nextSeedCst3); - Value *lshr3 = Builder.CreateLShr(mul2, 31); - Value *xor5 = Builder.CreateXor(lshr3, mul2); - return xor5; - } - - Value *insertNextSeedVectorize(Module &M, IRBuilder<> &Builder, Value *I) { - Type *int64Ty = Type::getInt64Ty(M.getContext()); - VECTOR_TYPE *VT = dyn_cast(I->getType()); - auto count = VT->getNumElements(); - auto size = CREATE_VECTOR_ELEMENT_COUNT(count); - auto *nextSeedCst1 = ConstantInt::get(int64Ty, -7046029254386353131); - auto *nextSeedCst1Vec = ConstantVector::getSplat(size, nextSeedCst1); - auto *add = Builder.CreateAdd(I, nextSeedCst1Vec); - auto *cst30 = ConstantVector::getSplat(size, Builder.getInt64(30)); - auto *lshr1 = Builder.CreateLShr(add, cst30); - auto *xor3 = Builder.CreateXor(lshr1, add); - auto *nextSeedCst2 = ConstantInt::get(int64Ty, -4658895280553007687); - auto *nextSeedCst2Vec = ConstantVector::getSplat(size, nextSeedCst2); - auto *mul1 = Builder.CreateMul(xor3, nextSeedCst2Vec); - auto *cst27 = ConstantVector::getSplat(size, Builder.getInt64(27)); - auto *lshr2 = Builder.CreateLShr(mul1, cst27); - auto *xor4 = Builder.CreateXor(lshr2, mul1); - auto *nextSeedCst3 = ConstantInt::get(int64Ty, -7723592293110705685); - auto *nextSeedCst3Vec = ConstantVector::getSplat(size, nextSeedCst3); - auto *mul2 = Builder.CreateMul(xor4, nextSeedCst3Vec); - auto *cst31 = ConstantVector::getSplat(size, Builder.getInt64(31)); - auto *lshr3 = Builder.CreateLShr(mul2, cst31); - auto *xor5 = Builder.CreateXor(lshr3, mul2); - return xor5; - } - - // clang-format off - // start init - // ; Function Attrs: nounwind uwtable - // define internal void @init() #3 { - // %1 = alloca %struct.timeval, align 8 - // %2 = load i1, i1* @already_initialized, align 1 - // br i1 %2, label %22, label %3 - // 3: ; preds = %0 - // store i1 true, i1* @already_initialized, - // align 1 - // %4 = bitcast %struct.timeval* %1 to i8* - // call void @llvm.lifetime.start.p0i8(i64 16, i8* nonnull %4) #12 - // %5 = call i32 @gettimeofday(%struct.timeval* noundef nonnull %1, i8*noundef null) #12 - // %6 = getelementptr inbounds %struct.timeval,%struct.timeval* %1, i64 0, i32 0 - // %7 = load i64, i64* %6, align 8, !tbaa !9 - // %8 = getelementptr inbounds %struct.timeval, %struct.timeval* %1, i64 0, i32 1 - // %9 = load i64, i64* %8, align 8, !tbaa !11 %10 = xor i64 %9, %7 - // %11 = tail call i64 (i64, ...) @syscall(i64 noundef 186) #12 - // %12 = xor i64 %10, %11 - // %13 = add i64 %12, -7046029254386353131 - // %14 = lshr i64 %13, 30 - // %15 = xor i64 %14, %13 - // %16 = mul i64 %15, -4658895280553007687 - // %17 = lshr i64 %16, 27 - // %18 = xor i64 %17, %16 - // %19 = mul i64 %18, -7723592293110705685 - // %20 = lshr i64 %19, 31 - // %21 = xor i64 %20, %19 - // store i64 %21, i64* @rng_state.0, align 16, !tbaa !5 - // store i64 %21, i64* @rng_state.1, align 16, !tbaa !5 - // call void @llvm.lifetime.end.p0i8(i64 16, i8* nonnull %4) #12 - // br label %22 - // 22: ; preds = %0, %3 - // ret void - // } - // clang-format on - Function *getOrCreateInitRNGFunction(Module &M) { - IRBuilder<> Builder(M.getContext()); - - const std::string function_name = "_sr_init_rng"; - Function *function = M.getFunction(function_name); - - if (function == nullptr or function->empty()) { - Type *voidTy = Type::getVoidTy(Builder.getContext()); - FunctionType *funcType = FunctionType::get(voidTy, {}, false); - - if (function == nullptr) { - function = createFunction(funcType, GlobalValue::InternalLinkage, M, - "_sr_init_rng"); - } - - BasicBlock *BB = - BasicBlock::Create(Builder.getContext(), "entry", function); - Builder.SetInsertPoint(BB); - - GlobalVariable *already_initialized = getGVAlreadyInitialized(M, Builder); - - Type *boolTy = Type::getInt1Ty(Builder.getContext()); - Value *already_initialized_load = - Builder.CreateLoad(boolTy, already_initialized); - BasicBlock *initBB = - BasicBlock::Create(Builder.getContext(), "init", function); - BasicBlock *retBB = - BasicBlock::Create(Builder.getContext(), "ret", function); - - Builder.CreateCondBr(already_initialized_load, retBB, initBB); - - Builder.SetInsertPoint(initBB); - Value *store = Builder.CreateStore( - ConstantInt::get(Type::getInt1Ty(Builder.getContext()), 1), - already_initialized); - - Type *int8PtrTy = Type::getInt8PtrTy(Builder.getContext()); - Type *int32Ty = Type::getInt32Ty(Builder.getContext()); - Type *int64Ty = Type::getInt64Ty(Builder.getContext()); - Value *seed = nullptr; - - if (VfclibSeed == -1) { - StructType *TimevalTy = - StructType::create(Builder.getContext(), "struct.timeval"); - std::vector Elements(2, int64Ty); - TimevalTy->setBody(Elements, /*isPacked=*/false); - - StructType *TimezoneTy = - StructType::create(Builder.getContext(), "struct.timezone"); - TimevalTy->setBody(Elements, /*isPacked=*/false); - - AllocaInst *AllocaTimeval = - Builder.CreateAlloca(TimevalTy, nullptr, "timeval"); - insertGettimeofdayCall(M, Builder, TimevalTy, TimezoneTy, - AllocaTimeval); - // Load timeval->tv_sec - Value *TvSecPtr = - Builder.CreateStructGEP(TimevalTy, AllocaTimeval, 0, "tv_sec_ptr"); - Value *tvsec = Builder.CreateLoad(int64Ty, TvSecPtr, "tv_sec"); - - // Load timeval->tv_usec - Value *TvUsecPtr = - Builder.CreateStructGEP(TimevalTy, AllocaTimeval, 1, "tv_usec_ptr"); - Value *tvusec = Builder.CreateLoad(int64Ty, TvUsecPtr, "tv_usec"); - Function *syscallF = getOrCreateSyscallFunction(M); - Constant *gettidSyscallId = ConstantInt::get(int64Ty, SYS_gettid); - Value *syscall = Builder.CreateCall(syscallF, gettidSyscallId); - Value *xor1 = Builder.CreateXor(tvsec, tvusec); - seed = Builder.CreateXor(xor1, syscall); - } else { - seed = ConstantInt::get(int64Ty, VfclibSeed); - } - - if (VfclibInstRNG == "xoroshiro") { - Value *nextSeed = insertNextSeed(M, Builder, seed); - GlobalVariable *rng_state0 = getGVRNGState(M, "rng_state.0"); - Builder.CreateStore(nextSeed, rng_state0); - nextSeed = insertNextSeed(M, Builder, nextSeed); - GlobalVariable *rng_state1 = getGVRNGState(M, "rng_state.1"); - Builder.CreateStore(nextSeed, rng_state1); - } else if (VfclibInstRNG == "shishua") { - - Value *nextSeed1 = insertNextSeed(M, Builder, seed); - Value *nextSeed2 = insertNextSeed(M, Builder, nextSeed1); - Value *nextSeed3 = insertNextSeed(M, Builder, nextSeed2); - Value *nextSeed4 = insertNextSeed(M, Builder, nextSeed3); - // create a vector with the 4 seeds - VECTOR_TYPE *VT = GET_VECTOR_TYPE(int64Ty, 4); - Value *seeds = UndefValue::get(VT); - seeds = - Builder.CreateInsertElement(seeds, nextSeed1, Builder.getInt32(0)); - seeds = - Builder.CreateInsertElement(seeds, nextSeed2, Builder.getInt32(1)); - seeds = - Builder.CreateInsertElement(seeds, nextSeed3, Builder.getInt32(2)); - seeds = - Builder.CreateInsertElement(seeds, nextSeed4, Builder.getInt32(3)); - - GlobalVariable *rng_state0 = getGVRNGState(M, "rng_state.0"); - Builder.CreateStore(seeds, rng_state0); - seeds = insertNextSeedVectorize(M, Builder, seeds); - GlobalVariable *rng_state1 = getGVRNGState(M, "rng_state.1"); - Builder.CreateStore(seeds, rng_state1); - GlobalVariable *rng_state2 = getGVRNGState(M, "rng_state.2"); - seeds = insertNextSeedVectorize(M, Builder, seeds); - Builder.CreateStore(seeds, rng_state2); - GlobalVariable *rng_state3 = getGVRNGState(M, "rng_state.3"); - seeds = insertNextSeedVectorize(M, Builder, seeds); - Builder.CreateStore(seeds, rng_state3); - } - - Builder.CreateBr(retBB); - Builder.SetInsertPoint(retBB); - Builder.CreateRetVoid(); - Builder.ClearInsertionPoint(); - } - return function; - } - - // clang-format off - // %23 = fsub double %22, %0 - // %24 = fsub double %22, %23 - // %25 = fsub double %0, %24 - // %26 = fsub double %1, %23 - // %27 = fadd double %26, %25 - // clang-format on - void insertTwoSum(IRBuilder<> &Builder, Value *a, Value *b, Value **sigma, - Value **tau) { - Value *sub1 = Builder.CreateFSub(*sigma, a); - Value *sub2 = Builder.CreateFSub(*sigma, sub1); - Value *sub3 = Builder.CreateFSub(a, sub2); - Value *sub4 = Builder.CreateFSub(b, sub1); - *tau = Builder.CreateFAdd(sub4, sub3, "tau"); - } - - // clang-format off - // %22 = fmul double %0, %1 - // %23 = fneg double %22 - // %24 = call double @llvm.fma.f64(double %0, double %1, double %23) #12 - // clang-format on - void insertTwoProduct(IRBuilder<> &Builder, Value *a, Value *b, Value **sigma, - Value **tau) { - *sigma = Builder.CreateFMul(a, b, "sigma"); - Value *neg = Builder.CreateFNeg(*sigma); - std::vector args = {a, b, neg}; - Type *retTy = a->getType(); - *tau = CREATE_FMA_CALL(Builder, retTy, args); - } - - // start sr_round_b64 - // %22 = fadd double %0, %1 ; actual fadd - // %23 = fsub double %22, %0 - // %24 = fsub double %22, %23 - // %25 = fsub double %0, %24 - // %26 = fsub double %1, %23 - // %27 = fadd double %26, %25 - // %28 = bitcast double %22 to i64 - // %29 = and i64 %28, 9218868437227405312 - // %30 = bitcast i64 %29 to double - // %31 = fmul double %30, 0x3CB0000000000000 - // %32 = fmul double %31, %21 - // %33 = fadd double %27, %32 - // %34 = bitcast double %33 to i64 - // %35 = and i64 %34, 9223372036854775807 - // %36 = bitcast i64 %35 to double - // %37 = bitcast double %31 to i64 - // %38 = and i64 %37, 9223372036854775807 - // %39 = bitcast i64 %38 to double - // %40 = fcmp ult double %36, %39 - // %41 = select i1 %40, double 0.000000e+00, double %31 - // %42 = fadd double %22, %41 - // ret double %42 - // end sr_round_b64 - Value *insertSRRounding(IRBuilder<> &Builder, Value *sigma, Value *tau, - Value *rand_double01) { - - Type *srcTy = sigma->getType(); - bool isVector = sigma->getType()->isVectorTy(); - VECTOR_TYPE *VT = dyn_cast(srcTy); - Type *fpTy = (isVector) ? VT->getElementType() : srcTy; - auto size = (isVector) ? VT->getNumElements() : 1; - - double ulpValue = 0.0; - Value *ulp = nullptr; - Constant *getExponent = nullptr, *getSign = nullptr; - - if (fpTy->isFloatTy()) { - ulpValue = c99hextodouble("0x1.0p-23"); - getExponent = Builder.getInt32(0x7f800000); - getSign = Builder.getInt32(0x7fffffff); - } else { - ulpValue = c99hextodouble("0x1.0p-52"); - getExponent = Builder.getInt64(0x7ff0000000000000); - getSign = Builder.getInt64(0x7fffffffffffffff); - } - - ulp = ConstantFP::get(srcTy, ulpValue); - - if (isVector) { - auto count = CREATE_VECTOR_ELEMENT_COUNT(size); - getExponent = ConstantVector::getSplat(count, getExponent); - getSign = ConstantVector::getSplat(count, getSign); - } - - Type *fpAsIntTy = getFPAsIntType(fpTy); - if (isVector) { - fpAsIntTy = GET_VECTOR_TYPE(fpAsIntTy, size); - } - - // if a float, cast rand_double01 to float - if (fpTy->isFloatTy()) { - rand_double01 = Builder.CreateFPCast(rand_double01, srcTy, "float cast"); - } - - // depending on the src type - Value *bitCast = Builder.CreateBitCast(sigma, fpAsIntTy); - Value *and1 = Builder.CreateAnd(bitCast, getExponent); - Value *bitCast2 = Builder.CreateBitCast(and1, srcTy); - Value *fmul = Builder.CreateFMul(bitCast2, ulp); - - Value *fmul2 = Builder.CreateFMul(fmul, rand_double01, "ulp * z"); - - Value *fadd = Builder.CreateFAdd(tau, fmul2); - Value *bitCast3 = Builder.CreateBitCast(fadd, fpAsIntTy); - - Value *and2 = Builder.CreateAnd(bitCast3, getSign); - Value *bitCast4 = Builder.CreateBitCast(and2, srcTy); - Value *bitCast5 = Builder.CreateBitCast(fmul, fpAsIntTy); - Value *and3 = Builder.CreateAnd(bitCast5, getSign); - Value *bitCast6 = Builder.CreateBitCast(and3, srcTy); - Value *ult = Builder.CreateFCmpULT(bitCast4, bitCast6); - - Value *zero = nullptr; - if (isVector) { - auto count = CREATE_VECTOR_ELEMENT_COUNT(size); - zero = ConstantVector::getSplat(count, ConstantFP::get(fpTy, 0.0)); - } else { - zero = ConstantFP::get(fpTy, 0.0); - } - - Value *select = Builder.CreateSelect(ult, zero, fmul); - Value *fadd2 = Builder.CreateFAdd(sigma, select); - return fadd2; - } - - // ; Function Attrs: mustprogress nofree nosync nounwind uwtable willreturn - // define dso_local double @add2_double(double noundef %0, double noundef - // %1) local_unnamed_addr #2 { start get_rand_uint64 - // %3 = load i64, i64* @rng_state.0, align 16, !tbaa !5 - // %4 = load i64, i64* @rng_state.1, align 16, !tbaa !5 - // %5 = add i64 %4, %3 - // %6 = shl i64 %5, 17 - // %7 = lshr i64 %5, 47 - // %8 = add i64 %7, %3 - // %9 = or i64 %8, %6 - // %10 = xor i64 %4, %3 - // %11 = shl i64 %3, 49 - // %12 = lshr i64 %3, 15 - // %13 = xor i64 %10, %12 - // %14 = shl i64 %10, 21 - // %15 = xor i64 %13, %14 - // %16 = or i64 %15, %11 - // store i64 %16, i64* @rng_state.0, align 16, !tbaa !5 - // %17 = call i64 @llvm.fshl.i64(i64 %10, i64 %10, i64 28) #12 - // store i64 %17, i64* @rng_state.1, align 16, !tbaa !5 - // end get_rand_uint64 - // start get_rand_double01 - // %18 = lshr i64 %9, 12 - // %19 = or i64 %18, 4607182418800017408 - // %20 = bitcast i64 %19 to double - // %21 = fadd double %20, -1.000000e+00 - // end get_rand_double01 - // %22 = fadd double %0, %1 ; actual fadd - // start twosum - // %23 = fsub double %22, %0 - // %24 = fsub double %22, %23 - // %25 = fsub double %0, %24 - // %26 = fsub double %1, %23 - // %27 = fadd double %26, %25 - // end twosum - // start sr_round_b64 - // %28 = bitcast double %22 to i64 - // %29 = and i64 %28, 9218868437227405312 - // %30 = bitcast i64 %29 to double - // %31 = fmul double %30, 0x3CB0000000000000 - // %32 = fmul double %31, %21 - // %33 = fadd double %27, %32 - // %34 = bitcast double %33 to i64 - // %35 = and i64 %34, 9223372036854775807 - // %36 = bitcast i64 %35 to double - // %37 = bitcast double %31 to i64 - // %38 = and i64 %37, 9223372036854775807 - // %39 = bitcast i64 %38 to double - // %40 = fcmp ult double %36, %39 - // %41 = select i1 %40, double 0.000000e+00, double %31 - // %42 = fadd double %22, %41 - // ret double %42 - // end sr_round_b64 - // } - Value *createSrAdd(IRBuilder<> &Builder, Instruction *I, - Function::arg_iterator args) { - - Value *a = static_cast(&args[0]); - Value *b = static_cast(&args[1]); - - Value *rand_double01 = nullptr; - bool isVector = a->getType()->isVectorTy(); - if (isVector) { - rand_double01 = insertVectorizeRandDouble01Call(Builder, I); - } else { - rand_double01 = insertRandDouble01Call(Builder, I); - } - - Value *tau = nullptr, *sigma = nullptr; - sigma = Builder.CreateFAdd(a, b, "sigma"); - insertTwoSum(Builder, a, b, &sigma, &tau); - Value *sr_round = insertSRRounding(Builder, sigma, tau, rand_double01); - return sr_round; - } - - Value *createSrSub(IRBuilder<> &Builder, Instruction *I, - Function::arg_iterator args) { - Value *a = static_cast(&args[0]); - Value *b = static_cast(&args[1]); - - bool isVector = a->getType()->isVectorTy(); - Value *rand_double01 = nullptr; - if (isVector) { - rand_double01 = insertVectorizeRandDouble01Call(Builder, I); - } else { - rand_double01 = insertRandDouble01Call(Builder, I); - } - - Value *tau = nullptr, *sigma = nullptr; - sigma = Builder.CreateFSub(a, b, "sigma"); - b = Builder.CreateFNeg(b); - insertTwoSum(Builder, a, b, &sigma, &tau); - Value *sr_round = insertSRRounding(Builder, sigma, tau, rand_double01); - return sr_round; + return M->getFunction(demangledShortNamesToMangled[name.str()]); } - // clang-format off - // Function Attrs: mustprogress nofree nosync nounwind uwtable willreturn - // define dso_local double @mul2_double(double noundef %0, double noundef %1) local_unnamed_addr #2 { - // start get_rand_uint64 - // %3 = load i64, i64* @rng_state.0, align 16, !tbaa !5 - // %4 = load i64, i64* @rng_state.1, align 16, !tbaa !5 - // %5 = add i64 %4, %3 - // %6 = shl i64 %5, 17 - // %7 = lshr i64 %5, 47 - // %8 = add i64 %7, %3 - // %9 = or i64 %8, %6 - // %10 = xor i64 %4, %3 - // %11 = shl i64 %3, 49 - // %12 = lshr i64 %3, 15 - // %13 = xor i64 %10, %12 - // %14 = shl i64 %10, 21 - // %15 = xor i64 %13, %14 - // %16 = or i64 %15, %11 - // store i64 %16, i64* @rng_state.0, align 16, !tbaa !5 - // %17 = call i64 @llvm.fshl.i64(i64 %10, i64 %10, i64 28) #12 - // store i64 %17, i64* @rng_state.1, align 16, !tbaa !5 - // end get_rand_uint64 - // start get_rand_double01 - // %18 = lshr i64 %9, 12 - // %19 = or i64 %18, 4607182418800017408 - // %20 = bitcast i64 %19 to double - // %21 = fadd double %20, -1.000000e+00 - // end get_rand_double01 - // start twoproduct - // %22 = fmul double %0, %1 - // %23 = fneg double %22 - // %24 = call double @llvm.fma.f64(double %0, double %1, double %23) #12 - // end twoproduct - // start sr_round_b64 - // %25 = bitcast double %22 to i64 - // %26 = and i64 %25, 9218868437227405312 - // %27 = bitcast i64 %26 to double - // %28 = fmul double %27, 0x3CB0000000000000 - // %29 = fmul double %28, %21 - // %30 = fadd double %24, %29 - // %31 = bitcast double %30 to i64 - // %32 = and i64 %31, 9223372036854775807 - // %33 = bitcast i64 %32 to double - // %34 = bitcast double %28 to i64 - // %35 = and i64 %34, 9223372036854775807 - // %36 = bitcast i64 %35 to double - // %37 = fcmp ult double %33, %36 - // %38 = select i1 %37, double 0.000000e+00, double %28 - // %39 = fadd double %22, %38 - // ret double %39 - // end sr_round_b64 - // } - // } - // clang-format on - Value *createSrMul(IRBuilder<> &Builder, Instruction *I, - Function::arg_iterator args) { - Value *a = static_cast(&args[0]); - Value *b = static_cast(&args[1]); - - Value *sigma = nullptr, *tau = nullptr; - - bool isVector = a->getType()->isVectorTy(); - Value *rand_double01 = nullptr; - if (isVector) { - rand_double01 = insertVectorizeRandDouble01Call(Builder, I); - } else { - rand_double01 = insertRandDouble01Call(Builder, I); - } - - insertTwoProduct(Builder, a, b, &sigma, &tau); - Value *sr_round = insertSRRounding(Builder, sigma, tau, rand_double01); - return sr_round; - } - - Value *createSrDiv(IRBuilder<> &Builder, Instruction *I, - Function::arg_iterator args) { - Value *a = static_cast(&args[0]); - Value *b = static_cast(&args[1]); - - Value *sigma = nullptr, *tau = nullptr; + std::string get_mangled_name(Function *F) { + // Create a Mangler + llvm::Mangler Mang; - Value *rand_double01 = nullptr; - bool isVector = a->getType()->isVectorTy(); - if (isVector) { - rand_double01 = insertVectorizeRandDouble01Call(Builder, I); - } else { - rand_double01 = insertRandDouble01Call(Builder, I); - } - - // clang-format off - // %22 = fdiv double %0, %1 - // %23 = fneg double %22 - // %24 = call double @llvm.fma.f64(double %23, double %1, double %0) - // %25 = fdiv double %24, %1 - // clang-format on - sigma = Builder.CreateFDiv(a, b); - Value *neg = Builder.CreateFNeg(sigma); - std::vector args_fma = {neg, b, a}; - Type *fmaRetType = a->getType(); - Value *fma = CREATE_FMA_CALL(Builder, fmaRetType, args_fma); - tau = Builder.CreateFDiv(fma, b); - Value *sr_round = insertSRRounding(Builder, sigma, tau, rand_double01); - return sr_round; + // Get the mangled name + std::string MangledName; + llvm::raw_string_ostream RawOS(MangledName); + Mang.getNameWithPrefix(RawOS, F, false); + return MangledName; } - Value *createSrFMA(IRBuilder<> &Builder, Instruction *I, - Function::arg_iterator args) { - Value *a = static_cast(&args[0]); - Value *b = static_cast(&args[1]); - Value *c = static_cast(&args[2]); - - Value *ph = nullptr, *pl = nullptr, *uh = nullptr, *ul = nullptr; - - bool isVector = a->getType()->isVectorTy(); - Value *rand_double01 = nullptr; - if (isVector) { - rand_double01 = insertVectorizeRandDouble01Call(Builder, I); + std::string get_demangled_name(const std::string &name) { + int status; + char *demangled = abi::__cxa_demangle(name.c_str(), 0, 0, &status); + if (status == 0) { + std::string demangled_name(demangled); + free(demangled); + return demangled_name; } else { - rand_double01 = insertRandDouble01Call(Builder, I); + return name; } - - insertTwoProduct(Builder, a, b, &ph, &pl); - uh = Builder.CreateFAdd(c, ph); - insertTwoSum(Builder, c, ph, &uh, &ul); - - std::vector args_sigma = {a, b, c}; - Type *retTy = a->getType(); - Value *sigma = CREATE_FMA_CALL(Builder, retTy, args_sigma); - Value *t = Builder.CreateFSub(uh, sigma); - Value *error = Builder.CreateFAdd(pl, ul); - Value *tau = Builder.CreateFAdd(error, t); - Value *sr_round = insertSRRounding(Builder, sigma, tau, rand_double01); - return sr_round; } - std::string getFunctionName(Fops opCode) { - switch (opCode) { - case Fops::FOP_ADD: - return "add2"; - case Fops::FOP_SUB: - return "sub2"; - case Fops::FOP_MUL: - return "mul2"; - case Fops::FOP_DIV: - return "div2"; - case Fops::FOP_FMA: - return "fma2"; - default: + std::string getFunctionName(Instruction *I, Fops opCode) { + if (opCode == FOP_IGNORE) { errs() << "Unsupported opcode: " << opCode << "\n"; report_fatal_error("libVFCInstrument fatal error"); } - } - - Value *createSrOp(IRBuilder<> &Builder, Instruction *I, - Function::arg_iterator args, Fops opCode) { - switch (opCode) { - case Fops::FOP_ADD: - return createSrAdd(Builder, I, args); - case Fops::FOP_SUB: - return createSrSub(Builder, I, args); - case Fops::FOP_MUL: - return createSrMul(Builder, I, args); - case Fops::FOP_DIV: - return createSrDiv(Builder, I, args); - case Fops::FOP_FMA: - return createSrFMA(Builder, I, args); - default: - errs() << "Unsupported opcode: " << opCode << "\n"; - report_fatal_error("libVFCInstrument fatal error"); - } - } - - Value *getOrCreateSrFunction(IRBuilder<> &Builder, Instruction *I, - Fops opCode) { - Module *M = Builder.GetInsertBlock()->getParent()->getParent(); - Type *srcTy = I->getType(); - - std::string function_name = - getFunctionName(opCode) + "_" + getTypeName(srcTy); - - // if instruction is vector, update the name - if (srcTy->isVectorTy()) { - VectorType *VT = dyn_cast(srcTy); -#if LLVM_VERSION_MAJOR < 9 - unsigned int size = VT->getNumElements(); -#else - ElementCount EC = VT->getElementCount(); - unsigned int size = EC.getFixedValue(); -#endif - function_name += "v" + std::to_string(size); - } - - Function *function = I->getModule()->getFunction(function_name); - - if (function == nullptr or function->empty()) { - Type *voidTy = Type::getVoidTy(I->getContext()); - - FunctionType *funcType = nullptr; - if (opCode == Fops::FOP_FMA) { - funcType = FunctionType::get(srcTy, {srcTy, srcTy, srcTy}, false); - } else { - funcType = FunctionType::get(srcTy, {srcTy, srcTy}, false); - } - if (function == nullptr) { - function = Function::Create(funcType, Function::InternalLinkage, - function_name, M); - } - - BasicBlock *BB = BasicBlock::Create(I->getContext(), "entry", function); - Builder.SetInsertPoint(BB); - - Function::arg_iterator args = function->arg_begin(); - Value *sr_op = createSrOp(Builder, I, &*args, opCode); - Builder.CreateRet(sr_op); - Builder.ClearInsertionPoint(); - } - - // Check if this instruction is the last one in the block - if (I->isTerminator()) { - // Insert at the end of the block - Builder.SetInsertPoint(I->getParent()); - } else { - // Set insertion point after the given instruction - Builder.SetInsertPoint(I->getNextNode()); - } - - if (opCode == Fops::FOP_FMA) { - std::vector args = {I->getOperand(0), I->getOperand(1), - I->getOperand(2)}; - return Builder.CreateCall(function, args); - } else { - std::vector args = {I->getOperand(0), I->getOperand(1)}; - return Builder.CreateCall(function, args); - } - } - - Value *replaceArithmeticWithMCACall_SR(IRBuilder<> &Builder, Instruction *I, - std::set &users) { - Fops opCode = mustReplace(*I); - Value *value = getOrCreateSrFunction(Builder, I, opCode); - return value; - } - Value *getOrCreateUpDownFunction(IRBuilder<> &Builder, Instruction *I, - Fops opCode, std::set &users) { - Module *M = Builder.GetInsertBlock()->getParent()->getParent(); - Type *srcTy = I->getType(); - - std::string function_name = - getFunctionName(opCode) + "_" + getTypeName(srcTy) + "_updown"; - - // if instruction is vector, update the name - if (srcTy->isVectorTy()) { - VectorType *VT = dyn_cast(srcTy); -#if LLVM_VERSION_MAJOR < 9 - unsigned int size = VT->getNumElements(); -#else - ElementCount EC = VT->getElementCount(); - unsigned int size = EC.getFixedValue(); -#endif - function_name += "v" + std::to_string(size); - } + auto baseType = I->getType(); + std::string libname = VfclibInstMode; + if (VfclibInstMode == "up-down") + libname = "ud"; - Function *function = I->getModule()->getFunction(function_name); + // Print the mangled name + // llvm::outs() << "Mangled name: " << MangledName << "\n"; - if (function == nullptr or function->empty()) { - Type *voidTy = Type::getVoidTy(I->getContext()); - - FunctionType *funcType = nullptr; - if (opCode == Fops::FOP_FMA) { - funcType = FunctionType::get(srcTy, {srcTy, srcTy, srcTy}, false); - } else { - funcType = FunctionType::get(srcTy, {srcTy, srcTy}, false); - } - if (function == nullptr) { - function = Function::Create(funcType, Function::InternalLinkage, - function_name, M); - } - - BasicBlock *BB = BasicBlock::Create(I->getContext(), "entry", function); - Builder.SetInsertPoint(BB); - - Function::arg_iterator args = function->arg_begin(); - Value *sr_op = createUpDownOp(Builder, I, &*args, opCode, users); - errs() << "sr_op: " << *sr_op << "\n"; - Builder.CreateRet(sr_op); - Builder.ClearInsertionPoint(); - } - - // Check if this instruction is the last one in the block - if (I->isTerminator()) { - // Insert at the end of the block - Builder.SetInsertPoint(I->getParent()); - } else { - // Set insertion point after the given instruction - Builder.SetInsertPoint(I->getNextNode()); - } + const std::string &libname_prefix = libname; + const std::string &opname = Fops2str[opCode]; + const std::string &fpname = validTypesMap[baseType->getTypeID()]; - if (opCode == Fops::FOP_FMA) { - std::vector args = {I->getOperand(0), I->getOperand(1), - I->getOperand(2)}; - return Builder.CreateCall(function, args); - } else { - std::vector args = {I->getOperand(0), I->getOperand(1)}; - return Builder.CreateCall(function, args); - } + // TODO: add vector size to the function name + return libname + "_" + opname + "_" + fpname; } - Value *replaceArithmeticWithMCACall_UpOrDown(IRBuilder<> &Builder, - Instruction *I, - std::set &users) { + Function *getMCAFunction(Instruction *I) { Fops opCode = mustReplace(*I); - Value *value = getOrCreateUpDownFunction(Builder, I, opCode, users); - return value; - } - - Value *insertOriginalInstruction(IRBuilder<> &Builder, Instruction *I, - Fops opCode, Function::arg_iterator args) { - Value *op = nullptr; - Value *arg1 = nullptr, *arg2 = nullptr, *arg3 = nullptr; - switch (opCode) { - case Fops::FOP_ADD: - arg1 = &*args; - args++; - arg2 = &*args; - op = Builder.CreateFAdd(arg1, arg2); - break; - case Fops::FOP_SUB: - arg1 = &*args; - args++; - arg2 = &*args; - op = Builder.CreateFSub(arg1, arg2); - break; - case Fops::FOP_MUL: - arg1 = &*args; - args++; - arg2 = &*args; - op = Builder.CreateFMul(arg1, arg2); - break; - case Fops::FOP_DIV: - arg1 = &*args; - args++; - arg2 = &*args; - op = Builder.CreateFDiv(arg1, arg2); - break; - case Fops::FOP_FMA: - // arg1 = &*args; - // args++; - // arg2 = &*args; - // args++; - // arg3 = &*args; - op = CREATE_FMA_CALL(Builder, arg1->getType(), args); - break; - default: - errs() << "Unsupported opcode: " << opCode << "\n"; - report_fatal_error("libVFCInstrument fatal error"); - } - return op; - } - - // clang-format off - // define dso_local <4 x float> @ud_round_b32_4x(<4 x float> noundef %0) local_unnamed_addr #10 { - // entry: - // %2 = fcmp une <4 x float> %0, zeroinitializer - // %3 = bitcast <4 x i1> %2 to i4 - // %4 = icmp eq i4 %3, 0 - // br i1 %4, label %15, label %5 - // noise: ; preds = entry - // %6 = bitcast <4 x float> %0 to <4 x i32> - // %7 = tail call i32 @get_rand_uint32() - // %8 = insertelement <4 x i32> poison, i32 %7, i64 0 - // %9 = shufflevector <4 x i32> %8, <4 x i32> poison, <4 x i32> zeroinitializer - // %10 = and <4 x i32> %9, - // %11 = icmp eq <4 x i32> %10, zeroinitializer - // %12 = select <4 x i1> %11, <4 x i32> , <4 x i32> - // %13 = add <4 x i32> %12, %6 - // %14 = bitcast <4 x i32> %13 to <4 x float> - // br label %15 - // ret: ; preds = entry, noise - // %16 = phi <4 x float> [ %14, %5 ], [ %0, %1 ] - // ret <4 x float> %16 - // } - // clang-format on - Value *createUpDownOp(IRBuilder<> &Builder, Instruction *I, - Function::arg_iterator args, Fops opCode, - std::set &users) { - - Type *srcTy = nullptr; - bool isVector = I->getType()->isVectorTy(); - VECTOR_TYPE *VT = dyn_cast(I->getType()); - srcTy = (isVector) ? VT->getElementType() : I->getType(); - Type *fpAsIntTy = getFPAsIntType(srcTy); - - BasicBlock *entryBB = Builder.GetInsertBlock(); - BasicBlock *noiseBB = - BasicBlock::Create(I->getContext(), "noise", entryBB->getParent()); - BasicBlock *retBB = - BasicBlock::Create(I->getContext(), "ret", entryBB->getParent()); - - // entry block - Builder.SetInsertPoint(entryBB); - Instruction *op = static_cast( - insertOriginalInstruction(Builder, I, opCode, args)); - - uint32_t size = 1; - auto count = CREATE_VECTOR_ELEMENT_COUNT(size); - - Value *randomBits = nullptr, *zeroinitializerInt = nullptr, - *zeroinitializerInt64 = nullptr, *zeroinitializerFP = nullptr, - *cmp = nullptr, *bitcast = nullptr, *icmp = nullptr; - Type *bitCastTy = nullptr; - - if (isVector) { - // entry: - // %2 = fcmp une <4 x float> %0, zeroinitializer - // %3 = bitcast <4 x i1> %2 to i4 - // %4 = icmp eq i4 %3, 0 - // br i1 %4, label %15, label %5 - srcTy = VT->getElementType(); - size = VT->getNumElements(); - count = CREATE_VECTOR_ELEMENT_COUNT(size); - auto *vecIntTy = VECTOR_TYPE::get(fpAsIntTy, size); - zeroinitializerFP = ConstantAggregateZero::get(I->getType()); - zeroinitializerInt = ConstantAggregateZero::get(vecIntTy); - zeroinitializerInt64 = ConstantAggregateZero::get( - VECTOR_TYPE::get(Builder.getInt64Ty(), size)); - // %2 = fcmp une <4 x float> %0, zeroinitializer - cmp = Builder.CreateFCmpUNE(op, zeroinitializerFP, "is_zero"); - auto *intNTy = Type::getIntNTy(I->getContext(), VT->getNumElements()); - // %3 = bitcast <4 x i1> %2 to i4 - bitcast = Builder.CreateBitCast(cmp, intNTy); - errs() << "bitcast: " << *bitcast << "\n"; - errs() << "zeroinitializerInt: " << *zeroinitializerInt << "\n"; - // Generate zeronitializer for bitcast - auto *zero = ConstantInt::get(intNTy, 0); - // %4 = icmp eq i4 %3, 0 - errs() << "zero: " << *zero << "\n"; - icmp = Builder.CreateICmpEQ(bitcast, zero); - Builder.CreateCondBr(icmp, retBB, noiseBB); - bitCastTy = VECTOR_TYPE::get(fpAsIntTy, VT->getNumElements()); - } else { - srcTy = I->getType(); - zeroinitializerFP = ConstantFP::get(srcTy, 0.0); - zeroinitializerInt = ConstantInt::get(fpAsIntTy, 0); - zeroinitializerInt64 = ConstantInt::get(Builder.getInt64Ty(), 0); - Constant *zeroFP = ConstantFP::get(I->getType(), 0.0); - cmp = Builder.CreateFCmpOEQ(op, zeroFP, "is_zero"); - Builder.CreateCondBr(cmp, retBB, noiseBB); - bitCastTy = fpAsIntTy; - } - - // noise block - Builder.SetInsertPoint(noiseBB); - Value *fpAsInt = Builder.CreateBitCast(op, bitCastTy, "fpAsInt"); - users.insert(static_cast(fpAsInt)); - - // if (isVector and VfclibInstRNG == "xoroshiro") { - // insertVectorizeRandUint64Call(Builder, op, &randomBits); - // } else { - // } - insertRandUint64Call(Builder, op, &randomBits); - - if (isVector) { - // clang-format off - // % 8 = insertelement<4 x i32> poison, i32 % 7, i64 0 - // % 9 = shufflevector<4 x i32> % 8, <4 x i32> poison, <4 x i32> zeroinitializer - // % 10 = and<4 x i32> % 9, - // clang-format on - // if (srcTy->isFloatTy()) { - // randomBits = Builder.CreateBitCast(randomBits, fpAsIntTy, - // "randBits_to_fp"); - // } - auto *vecTy = VECTOR_TYPE::get(Builder.getInt64Ty(), size); - auto *poison = UndefValue::get(vecTy); - auto *zero = Builder.getInt64(0); - auto *insert = - Builder.CreateInsertElement(poison, randomBits, zero, "insertZero"); - auto *shuffle = - Builder.CreateShuffleVector(insert, poison, zeroinitializerInt); - std::vector pow2; - for (unsigned i = 0; i < VT->getNumElements(); i++) { - pow2.push_back(1 << i); - } - Module *M = Builder.GetInsertBlock()->getParent()->getParent(); - auto *cstPow2 = getVectorConstant(M, pow2); - randomBits = Builder.CreateAnd(shuffle, cstPow2); - } else { - // %6 = and i32 %5, 1 - randomBits = Builder.CreateAnd(randomBits, 1); - } - - errs() << "randomBits: " << *randomBits << "\n"; - errs() << "zeroinitializerInt64: " << *zeroinitializerInt64 << "\n"; - auto *zero = ConstantInt::get(randomBits->getType(), 0); - auto *icmp2 = Builder.CreateICmpEQ(randomBits, zero); - - Constant *one = ConstantInt::get(fpAsIntTy, 1); - Constant *mone = ConstantInt::get(fpAsIntTy, -1); - if (isVector) { - one = ConstantVector::getSplat(count, one); - mone = ConstantVector::getSplat(count, mone); - } - - auto *select = Builder.CreateSelect(icmp2, mone, one); - auto *add2 = Builder.CreateAdd(fpAsInt, select); - auto *fpNoised = Builder.CreateBitCast(add2, I->getType()); - Builder.CreateBr(retBB); - - // ret block - Builder.SetInsertPoint(retBB); - PHINode *phiNode = Builder.CreatePHI(I->getType(), 2); - phiNode->addIncoming(fpNoised, noiseBB); - phiNode->addIncoming(op, entryBB); - - return phiNode; + std::string functionName = getFunctionName(I, opCode); + return getFunction(I->getModule(), functionName); } /* Replace arithmetic instructions with MCA */ - Value *replaceArithmeticWithMCACall(IRBuilder<> &Builder, Instruction *I, - std::set &users) { - if (VfclibInstMode == "up-down") { - return replaceArithmeticWithMCACall_UpOrDown(Builder, I, users); - } else if (VfclibInstMode == "sr") { - return replaceArithmeticWithMCACall_SR(Builder, I, users); - } else { - errs() << "Unsupported mode: " << VfclibInstMode << "\n"; - report_fatal_error("libVFCInstrument fatal error"); - } + Value *replaceArithmeticWithMCACall(IRBuilder<> &Builder, Instruction *I) { + Function *F = getMCAFunction(I); + // get arguments of the instruction + std::vector operands(I->op_begin(), I->op_end()); + return Builder.CreateCall(F, operands); } - Value *replaceWithMCACall(Module &M, Instruction *I, Fops opCode, - std::set &users) { + Value *replaceWithMCACall(Module &M, Instruction *I, Fops opCode) { if (not isValidInstruction(I)) { return nullptr; } - IRBuilder<> Builder(I->getContext()); + IRBuilder<> Builder(I); + // Check if this instruction is the last one in the block if (I->isTerminator()) { // Insert at the end of the block @@ -2471,12 +578,12 @@ struct VfclibInst : public ModulePass { // We call directly a hardcoded helper function // no need to go through the vtable at this stage. - Value *newInst = replaceArithmeticWithMCACall(Builder, I, users); + Value *newInst = replaceArithmeticWithMCACall(Builder, I); return newInst; } - bool isFMAOperation(Instruction &I) { - CallInst *CI = static_cast(&I); + bool isFMAOperation(const Instruction &I) { + auto CI = static_cast(&I); if (CI->getCalledFunction() == nullptr) return false; const std::string &name = CI->getCalledFunction()->getName().str(); @@ -2491,7 +598,7 @@ struct VfclibInst : public ModulePass { return false; } - Fops mustReplace(Instruction &I) { + Fops getOpCode(const Instruction &I) { switch (I.getOpcode()) { case Instruction::FAdd: return FOP_ADD; @@ -2504,11 +611,8 @@ struct VfclibInst : public ModulePass { return FOP_DIV; case Instruction::Call: // Only instrument FMA if the flag --inst-fma is passed - if (VfclibInstInstrumentFMA and isFMAOperation(I)) { - return FOP_FMA; - } else { - return FOP_IGNORE; - } + return (VfclibInstInstrumentFMA and isFMAOperation(I)) ? FOP_FMA + : FOP_IGNORE; case Instruction::FCmp: return FOP_IGNORE; default: @@ -2516,84 +620,13 @@ struct VfclibInst : public ModulePass { } } - void replaceUsageWith(std::set &users, Value *from, Value *to) { - for (User *user : from->users()) { - if (Instruction *ii = dyn_cast(user)) { - if (users.find(ii) == users.end()) { - for (auto &op : ii->operands()) { - if (op.get() == from) { - errs() << "Replacing " << *from << " with " << *to << " in " - << *ii << '\n'; - op.set(to); - } - } - } - } - } - } - - void insertShishuaCall(Module &M, Instruction *I) { - auto *int32Ty = Type::getInt32Ty(M.getContext()); - auto *int64Ty = Type::getInt64Ty(M.getContext()); - auto *function_name = "_shishua_uint641"; - FunctionType *shishuType = FunctionType::get(int64Ty, {int32Ty}, false); - - Function *shishua = getOrInsertFunction(M, function_name, shishuType); - IRBuilder<> Builder(I->getContext()); - Builder.SetInsertPoint(I); - std::vector args = {Builder.getInt32(1)}; - auto *call = Builder.CreateCall(shishua, args); - I->replaceAllUsesWith(call); - } - - void replaceArgsWithShishuaInfo(Module &M, Instruction *I) { - // call void @print_buffer_(i32 noundef %1, i8* noundef getelementptr - // inbounds ([256 x i8], [256 x i8]* @buf, i64 0, i64 0)) - IRBuilder<> Builder(I->getContext()); - Builder.SetInsertPoint(I); - // get @shishua_buffer - GlobalVariable *rand_uint64_buffer = getGVRandUint64Buffer(M); - // get @shishua_buffer_size - GlobalVariable *rand_uint64_buffer_size = getGVRandUint64Counter(M); - auto *call = static_cast(I); - // replace the first argument with @shishua_buffer_size - // dereference the pointer - auto *int32Ty = Type::getInt32Ty(M.getContext()); - auto *load = Builder.CreateLoad(int32Ty, rand_uint64_buffer_size); - call->setArgOperand(0, load); - // replace the second argument with @shishua_buffer - // cast the pointer to uint8_t* - auto *int8PtrTy = Type::getInt8PtrTy(M.getContext()); - // i8* noundef getelementptr inbounds ([256 x i8], [256 x i8]* @buf, i64 0, - // i64 0) - auto *zero = Builder.getInt64(0); - auto *zero2 = Builder.getInt64(0); - std::vector gep_args = {zero, zero2}; - Type *type = rand_uint64_buffer->getType()->getPointerElementType(); - errs() << "type: " << *type << '\n'; - errs() << "rand_uint64_buffer: " << *rand_uint64_buffer << '\n'; - auto *gep = Builder.CreateGEP(type, rand_uint64_buffer, gep_args); - call->setArgOperand(1, gep); - } + Fops mustReplace(Instruction &I) { return getOpCode(I); } bool runOnBasicBlock(Module &M, BasicBlock &B) { bool modified = false; std::set> WorkList; - for (BasicBlock::iterator ii = B.begin(), ie = B.end(); ii != ie; ++ii) { - Instruction &I = *ii; - - if (CallInst *CI = dyn_cast(&I)) { - Function *F = CI->getCalledFunction(); - if (F->getName() == "CALL_SHISHUA") { - insertShishuaCall(M, CI); - continue; - } else if (F->getName() == "print_shishua_info") { - replaceArgsWithShishuaInfo(M, CI); - continue; - } - } - - Fops opCode = mustReplace(I); + for (auto &I : B) { + auto opCode = mustReplace(I); if (opCode == FOP_IGNORE) continue; WorkList.insert(std::make_pair(&I, opCode)); @@ -2604,25 +637,15 @@ struct VfclibInst : public ModulePass { Fops opCode = p.second; if (VfclibInstVerbose) errs() << "Instrumenting" << *I << '\n'; - std::set fp_users; - Value *value = replaceWithMCACall(M, I, opCode, fp_users); + Value *value = replaceWithMCACall(M, I, opCode); if (value != nullptr) { - I->replaceAllUsesWith(value); - I->eraseFromParent(); - modified = true; + BasicBlock::iterator ii(I); +#if LLVM_VERSION_MAJOR >= 16 + ReplaceInstWithValue(ii, value); +#else + ReplaceInstWithValue(B.getInstList(), ii, value); +#endif } - // if (value != nullptr and VfclibInstMode != "up-down") { - // I->replaceAllUsesWith(value); - // I->eraseFromParent(); - // modified = true; - // } else if (value != nullptr and VfclibInstMode == "up-down") { - // // We need to replace the original instruction with the noised - // // instruction for all the users of the original instruction after - // the - // // noise is added - // replaceUsageWith(fp_users, I, value); - // modified = true; - // } } return modified; diff --git a/src/libvfcinstrumentonline/rand.cpp b/src/libvfcinstrumentonline/rand.cpp new file mode 100644 index 00000000..66a6eb84 --- /dev/null +++ b/src/libvfcinstrumentonline/rand.cpp @@ -0,0 +1,334 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// Include other necessary headers (float_const.h, float_struct.h, shishua.h, +// vector_types.h) + +// force fma hardware instruction +// fail if not supported + +#ifdef DEBUG +#include +#define debug_print(fmt, ...) fprintf(stderr, fmt, __VA_ARGS__); +#else +#define debug_print(fmt, ...) +#endif + +#ifdef USE_BOOST_TLS +#define BOOST_THREAD_USE_LIB +#include +#ifdef XOROSHIRO +using xoroshiro_state = std::array; +boost::thread_specific_ptr rng_state; +#elif defined(SHISHUA) +boost::thread_specific_ptr rng_state; +boost::thread_specific_ptr shishua_buffer_index; +boost::thread_specific_ptr> buf; +#else +#error "No PRNG defined" +#endif +#else +#ifdef XOROSHIRO +#include "xoroshiro128+.cpp" +typedef __INTERNAL_RNG_STATE rng_state_t; +static __thread rng_state_t rng_state; +#elif defined(SHISHUA) +#include "shishua.h" +typedef prng_state rng_state_t; +static __thread rng_state_t rng_state; +#else +#error "No PRNG defined" +#endif +#endif + +static pid_t global_tid = 0; +static bool already_initialized = false; + +uint64_t next_seed(uint64_t seed_state) { + uint64_t z = (seed_state += UINT64_C(0x9E3779B97F4A7C15)); + z = (z ^ (z >> 30)) * UINT64_C(0xBF58476D1CE4E5B9); + z = (z ^ (z >> 27)) * UINT64_C(0x94D049BB133111EB); + return z ^ (z >> 31); +} + +#define rotl(x, k) ((x << k) | (x >> (64 - k))) + +uint64_t _get_rand_uint64() { +#ifdef XOROSHIRO +#ifdef USE_BOOST_TLS + auto state = rng_state.get(); + if (!state) { + state = new xoroshiro_state(); + rng_state.reset(state); + } +#else +#endif + __m256i result = xoroshiro128plus_avx2_next(&rng_state); + return _mm256_extract_epi32(result, 0); +#elif defined(SHISHUA) +#ifdef USE_BOOST_TLS + return prng_uint64(rng_state.get()); +#else + return prng_uint64(&rng_state); +#endif +#else +#error "No PRNG defined" +#endif +} + +#ifdef XOROSHIRO +template T get_rand_uint() { + return static_cast(_get_rand_uint64()); +} +#elif defined(SHISHUA) +template T get_rand_uint() { + return static_cast(prng_uint64(&rng_state)); +} +#endif + +uint32_t get_rand_uint32_t() { return get_rand_uint(); } +uint64_t get_rand_uint64_t() { return get_rand_uint(); } + +template +typename std::enable_if::value, int64_t>::type +get_rand_float(T a) { + const uint64_t half_max_int = UINT64_MAX / 2; + if (a == 0) + return 0; + return (_get_rand_uint64() < half_max_int) ? 1 : -1; +} + +// TODO: check if this is correct +float get_rand_float01() { + uint32_t x = get_rand_uint32_t(); + const union { + uint32_t i; + float f; + } u = {.i = UINT32_C(0x3F800000) | (x >> 9)}; + return u.f - 1.0f; +} + +double get_rand_double01() { + uint64_t x = get_rand_uint64_t(); + const union { + uint64_t i; + double d; + } u = {.i = UINT64_C(0x3FF) << 52 | x >> 12}; + return u.d - 1.0; +} + +void init() { + if (already_initialized) + return; + already_initialized = true; + + uint64_t seed = 1; + struct timeval t1; + gettimeofday(&t1, nullptr); + seed = t1.tv_sec ^ t1.tv_usec ^ syscall(__NR_gettid); + +#ifdef XOROSHIRO + xoroshiro128plus_avx2_init(&rng_state, seed); +#elif defined(SHISHUA) + uint64_t seed_state[4] = {next_seed(seed), next_seed(seed), next_seed(seed), + next_seed(seed)}; + prng_init(&rng_state, seed_state); +#else +#error "No PRNG defined" +#endif +} + +// Implement other functions (get_exponent, predecessor, abs, pow2, etc.) using +// templates +template T predecessor(T a) { + if (a == 0) + return -std::numeric_limits::denorm_min(); + // return previous floating-point number + if constexpr (std::is_same::value) { + uint32_t i = *reinterpret_cast(&a); + i -= 1; + return *reinterpret_cast(&i); + } else { + uint64_t i = *reinterpret_cast(&a); + i -= 1; + return *reinterpret_cast(&i); + } +} + +template uint32_t get_exponent(T a) { + // return exponent of a + if constexpr (std::is_same::value) { + return (reinterpret_cast(a) >> 23) & 0xFF; + } else { + return (reinterpret_cast(a) >> 52) & 0x7FF; + } +} + +template T pow2(I n) { + if constexpr (std::is_same::value) { + float x = 1.0f; + uint32_t i = *reinterpret_cast(&x); + i += n << 23; + x = *reinterpret_cast(&i); + return x; + } else { + double x = 1.0; + uint64_t i = *reinterpret_cast(&x); + i += static_cast(n) << 52; + x = *reinterpret_cast(&i); + return x; + } +} + +template T sr_round(T sigma, T tau, T z) { + T eps = std::is_same::value ? 0x1.0p-23f : 0x1.0p-52; + bool sign_tau = tau < 0; + bool sign_sigma = sigma < 0; + uint32_t eta; + if (sign_tau != sign_sigma) { + eta = get_exponent(predecessor(std::abs(sigma))); + } else { + eta = get_exponent(sigma); + } + T ulp = (sign_tau ? 1 : -1) * pow2(eta) * eps; + T pi = ulp * z; + T round; + if (std::abs(tau + pi) >= std::abs(ulp)) { + round = ulp; + } else { + round = 0; + } + return round; +} + +template T ud_round(T a) { + if (a == 0) + return a; + using IntType = + typename std::conditional::type; + IntType a_bits; + std::memcpy(&a_bits, &a, sizeof(T)); + uint64_t rand = _get_rand_uint64(); + a_bits += (rand & 0x01) ? 1 : -1; + std::memcpy(&a, &a_bits, sizeof(T)); + return a; +} + +template T ud_add(T a, T b) { return ud_round(a + b); } +template T ud_sub(T a, T b) { return ud_add(a, -b); } +template T ud_mul(T a, T b) { return ud_round(a * b); } +template T ud_div(T a, T b) { return ud_round(a / b); } + +// Implement vector versions of ud_round (ud_round_b32_2x, ud_round_b32_4x, +// etc.) + +template void twosum(T a, T b, T *tau, T *sigma) { + *sigma = a + b; + T z = *sigma - a; + *tau = (a - (*sigma - z)) + (b - z); +} + +template +__attribute__((target("fma"))) void twoprodfma(T a, T b, T *tau, T *sigma) { + *sigma = a * b; + *tau = std::fma(a, b, -(*sigma)); +} + +template T sr_add(T a, T b) { + T z = get_rand_double01(); + T tau, sigma, round; + twosum(a, b, &tau, &sigma); + round = sr_round(sigma, tau, z); + return sigma + round; +} + +template T sr_sub(T a, T b) { return sr_add(a, -b); } + +template T sr_mul(T a, T b) { + T z = get_rand_double01(); + T tau, sigma, round; + twoprodfma(a, b, &tau, &sigma); + round = sr_round(sigma, tau, z); + return sigma + round; +} + +template T add_round_odd(T a, T b) { + // return addition with rounding to odd + // https://www.lri.fr/~melquion/doc/08-tc.pdf + T x, e; + twosum(a, b, &x, &e); + return (e == 0 || *reinterpret_cast(&x) & 1) ? x : x + 1; +} + +template T __attribute__((target("fma"))) sr_div(T a, T b) { + T z = get_rand_double01(); + T sigma = a / b; + T tau = std::fma(-sigma, b, a) / b; + T round = sr_round(sigma, tau, z); + return sigma + round; +} + +template T __attribute__((target("fma,sse2"))) sr_sqrt(T a) { + T z = get_rand_double01(); + T sigma; +#if defined(__SSE2__) + if constexpr (std::is_same::value) { + // call assembly sqrtss + asm("sqrtss %1, %0" + : "=x"(sigma) // Output operand + : "x"(a) // Input operand + ); + } else { + // call assembly sqrtsd + asm("sqrtsd %1, %0" + : "=x"(sigma) // Output operand + : "x"(a) // Input operand + ); + } +#else + T sigma = std::sqrt(a); +#endif + T tau = std::fma(-sigma, sigma, a) / (2 * sigma); + T round = sr_round(sigma, tau, z); + return sigma + round; +} + +// specializations for float and double +float sr_add_float(float a, float b) { return sr_add(a, b); } +float sr_sub_float(float a, float b) { return sr_sub(a, b); } +float sr_mul_float(float a, float b) { return sr_mul(a, b); } +float sr_div_float(float a, float b) { return sr_div(a, b); } +float sr_sqrt_float(float a) { return sr_sqrt(a); } + +double sr_add_double(double a, double b) { return sr_add(a, b); } +double sr_sub_double(double a, double b) { return sr_sub(a, b); } +double sr_mul_double(double a, double b) { return sr_mul(a, b); } +double sr_div_double(double a, double b) { return sr_div(a, b); } +double sr_sqrt_double(double a) { return sr_sqrt(a); } + +float ud_add_float(float a, float b) { return ud_add(a, b); } +float ud_sub_float(float a, float b) { return ud_sub(a, b); } +float ud_mul_float(float a, float b) { return ud_mul(a, b); } +float ud_div_float(float a, float b) { return ud_div(a, b); } + +double ud_add_double(double a, double b) { return ud_add(a, b); } +double ud_sub_double(double a, double b) { return ud_sub(a, b); } +double ud_mul_double(double a, double b) { return ud_mul(a, b); } +double ud_div_double(double a, double b) { return ud_div(a, b); } + +// Use a global object to ensure initialization +struct Initializer { + Initializer() { init(); } +}; + +Initializer initializer; \ No newline at end of file diff --git a/src/libvfcinstrumentonline/shishua.h b/src/libvfcinstrumentonline/shishua.h index c327c737..2e41c1a1 100644 --- a/src/libvfcinstrumentonline/shishua.h +++ b/src/libvfcinstrumentonline/shishua.h @@ -12,7 +12,9 @@ typedef struct prng_state { __m256i counter; } prng_state; -static inline void prng_gen(struct prng_state *s, uint8_t *buf, size_t size) { +uint8_t prng_buf[1024] __attribute__((aligned(32))); + +static inline void prng_gen(struct prng_state *s, uint8_t *buf) { __m256i o0 = s->output[0], o1 = s->output[1], o2 = s->output[2], o3 = s->output[3]; __m256i s0 = s->state[0], s1 = s->state[1], s2 = s->state[2], @@ -24,7 +26,7 @@ static inline void prng_gen(struct prng_state *s, uint8_t *buf, size_t size) { const __m256i increment = _mm256_set_epi64x(1, 3, 5, 7); size_t i = 0; - for (; i + 128 <= size; i += 128) { + for (; i + 128 <= 1024; i += 128) { _mm256_store_si256((__m256i *)&buf[i], o0); _mm256_store_si256((__m256i *)&buf[i + 32], o1); _mm256_store_si256((__m256i *)&buf[i + 64], o2); @@ -65,16 +67,23 @@ static inline void prng_gen(struct prng_state *s, uint8_t *buf, size_t size) { s->state[3] = s3; s->counter = counter; - if (i < size) { + if (i < 1024) { uint8_t temp[128] __attribute__((aligned(32))); _mm256_store_si256((__m256i *)&temp[0], o0); _mm256_store_si256((__m256i *)&temp[32], o1); _mm256_store_si256((__m256i *)&temp[64], o2); _mm256_store_si256((__m256i *)&temp[96], o3); - memcpy(buf + i, temp, size - i); + memcpy(buf + i, temp, 1024 - i); } } +uint64_t prng_uint64(struct prng_state *s) { + uint64_t result; + prng_gen(s, prng_buf); + memcpy(&result, prng_buf, sizeof(result)); + return result; +} + void prng_init(struct prng_state *s, const uint64_t seed[4]) { static const uint64_t phi[16] __attribute__((aligned(32))) = { 0x9E3779B97F4A7C15, 0xF39CC0605CEDC834, 0x1082276BF3A27251, @@ -96,9 +105,8 @@ void prng_init(struct prng_state *s, const uint64_t seed[4]) { s->counter = _mm256_setzero_si256(); - uint8_t buf[1024] __attribute__((aligned(32))); for (int i = 0; i < 13; ++i) { - prng_gen(s, buf, sizeof(buf)); + prng_gen(s, prng_buf); s->state[0] = s->output[3]; s->state[1] = s->output[2]; s->state[2] = s->output[1]; diff --git a/src/libvfcinstrumentonline/xoroshiro128+.cpp b/src/libvfcinstrumentonline/xoroshiro128+.cpp new file mode 100644 index 00000000..8caa4d09 --- /dev/null +++ b/src/libvfcinstrumentonline/xoroshiro128+.cpp @@ -0,0 +1,104 @@ +#include +#include +#include +#include +#include + +#define STATE_SIZE 2 + +typedef struct { + __m256i s[STATE_SIZE]; +} xoroshiro128plus_avx2_state; + +#define __INTERNAL_RNG_STATE xoroshiro128plus_avx2_state + +// Function to rotate left (vectorized) +static inline __m256i rotl_avx2(__m256i x, int k) { + return _mm256_or_si256(_mm256_slli_epi64(x, k), _mm256_srli_epi64(x, 64 - k)); +} + +// XOROSHIRO128++ next function (vectorized) +__m256i xoroshiro128plus_avx2_next(xoroshiro128plus_avx2_state *state) { + const __m256i s0 = state->s[0]; + __m256i s1 = state->s[1]; + const __m256i result = + _mm256_add_epi64(rotl_avx2(_mm256_add_epi64(s0, s1), 17), s0); + + s1 = _mm256_xor_si256(s1, s0); + state->s[0] = _mm256_xor_si256(_mm256_xor_si256(rotl_avx2(s0, 49), s1), + _mm256_slli_epi64(s1, 21)); + state->s[1] = rotl_avx2(s1, 28); + + return result; +} + +// Function to initialize the state +void xoroshiro128plus_avx2_init(xoroshiro128plus_avx2_state *state, + uint64_t seed) { + uint64_t temp_state[4][2]; + for (int i = 0; i < 4; i++) { + temp_state[i][0] = seed + i; + temp_state[i][1] = (seed + i) ^ 0x1234567890abcdefULL; + } + + state->s[0] = _mm256_loadu_si256((__m256i *)&temp_state[0][0]); + state->s[1] = _mm256_loadu_si256((__m256i *)&temp_state[0][1]); + + // Warm up the state + for (int i = 0; i < 100; i++) { + xoroshiro128plus_avx2_next(state); + } +} + +// Generate random numbers using AVX2 +void generate_random_numbers_avx2(uint64_t *output, size_t count) { + xoroshiro128plus_avx2_state state; + xoroshiro128plus_avx2_init(&state, 12345); + + size_t i = 0; + for (; i + 4 <= count; i += 4) { + __m256i result = xoroshiro128plus_avx2_next(&state); + _mm256_storeu_si256((__m256i *)&output[i], result); + } + + // Handle any remaining elements + if (i < count) { + __m256i result = xoroshiro128plus_avx2_next(&state); + uint64_t temp[4]; + _mm256_storeu_si256((__m256i *)temp, result); + memcpy(&output[i], temp, (count - i) * sizeof(uint64_t)); + } +} + +#ifdef XOROSHIRO_TEST +int main(int argc, char *argv[]) { + if (argc != 2) { + fprintf(stderr, "Usage: %s \n", argv[0]); + return 1; + } + + // read count from input + const size_t count = atoi(argv[1]); + + uint64_t *random_numbers = + (uint64_t *)aligned_alloc(32, count * sizeof(uint64_t)); + + if (!random_numbers) { + fprintf(stderr, "Memory allocation failed\n"); + return 1; + } + + generate_random_numbers_avx2(random_numbers, count); + + printf("Generated %zu random numbers\n", count); + + // Print first few numbers as a sample + for (int i = 0; i < 10; i++) { + printf("%lu ", random_numbers[i]); + } + printf("\n"); + + free(random_numbers); + return 0; +} +#endif \ No newline at end of file diff --git a/src/libvfcinstrumentonline/xoroshiro128+.hpp b/src/libvfcinstrumentonline/xoroshiro128+.hpp new file mode 100644 index 00000000..8c4b617b --- /dev/null +++ b/src/libvfcinstrumentonline/xoroshiro128+.hpp @@ -0,0 +1,17 @@ +#include +#include +#include +#include +#include + +#define STATE_SIZE 2 + +typedef struct { + __m256i s[STATE_SIZE]; +} xoroshiro128plus_avx2_state; + +#define __INTERNAL_RNG_STATE xoroshiro128plus_avx2_state + +void xoroshiro128plus_avx2_init(xoroshiro128plus_avx2_state *state, + uint64_t seed); +__m256i xoroshiro128plus_avx2_next(xoroshiro128plus_avx2_state *state); \ No newline at end of file diff --git a/tests/test_online_instrumentation/test.c b/tests/test_online_instrumentation/test.c index bc947463..6d60e062 100644 --- a/tests/test_online_instrumentation/test.c +++ b/tests/test_online_instrumentation/test.c @@ -61,32 +61,8 @@ float fma_float(float a, float b, float c) { return __builtin_fmaf(a, b, c); } abort(); \ } -__attribute__((noinline)) REAL operator(char op, REAL a, REAL b, - REAL c){OPERATOR(a, b, c)} - -int64_t CALL_SHISHUA(int32_t) { - return 0; -} - -void print_shishua_info(uint32_t shishua_buffer_index, uint8_t *buf) { - printf("Shishua buffer: \n"); - printf("Buffer index: %d\n", shishua_buffer_index); - for (int i = 0; i < 256; i++) { - if (i != 0 && i % 32 == 0) - printf("\n"); - printf("%02x ", buf[i]); - } - printf("\n"); -} - -void shishua_test() { - for (int i = 0; i < 512; i++) { - print_shishua_info(0, NULL); - if (i != 0 && i % 32 == 0) - printf("\n"); - printf("SHISHUA: %ld\n", CALL_SHISHUA(0)); - } - printf("\n"); +__attribute__((noinline)) REAL operator(char op, REAL a, REAL b, REAL c) { + OPERATOR(a, b, c) } int main(int argc, const char *argv[]) { @@ -111,7 +87,5 @@ int main(int argc, const char *argv[]) { printf("%.13a\n", operator(op, a, b, c)); - shishua_test(); - return EXIT_SUCCESS; } diff --git a/verificarlo.in.in b/verificarlo.in.in index d049efe4..44b6a226 100644 --- a/verificarlo.in.in +++ b/verificarlo.in.in @@ -237,15 +237,27 @@ def apply_function_instrumentation_pass(ir, ins, args): verbose=args.verbose, ) +def get_vfclibinst_online_ll(rng, verbose): + # TODO: cache the generated file + vfclibinst_online_c = f"{mcalib_includes}/rand.cpp" + vfclibinst_online_ll = "/tmp/rand.ll" + rng = rng.upper() + shell( + f"{clangxx} -std=c++17 -mavx2 -O3 -c -S -emit-llvm -D{rng} {vfclibinst_online_c} -o {vfclibinst_online_ll}", + verbose=verbose, + ) + return vfclibinst_online_ll # Apply MCA instrumentation pass def apply_mca_instrumentation_pass( ir, ins, vfcwrapper_ir, extra_args, selectfunction, args ): if args.online_instrumentation: + vfclibinst_online_ll = get_vfclibinst_online_ll(args.online_instrumentation_rng, args.verbose) libvfcinst = libvfcinstrumentonline extra_args += f" -vfclibinst-mode {args.online_instrumentation} " extra_args += f" -vfclibinst-rng {args.online_instrumentation_rng} " + extra_args += f" -vfclibinst-sr-file {vfclibinst_online_ll} " if args.online_instrumentation_seed: extra_args += f" -vfclibinst-seed {args.online_instrumentation_seed} " @@ -423,7 +435,7 @@ def parse_args(): parser.add_argument( "--online-instrumentation", action="store", - choices=["up-down", "sr"], + choices=["up-down", "sr", "dummy"], type=str, help="Online instrumentation. (mode: up-down, sr)", )