Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CUDA benchmarks might be misleading #47

Open
zimonitrome opened this issue Feb 19, 2022 · 1 comment
Open

CUDA benchmarks might be misleading #47

zimonitrome opened this issue Feb 19, 2022 · 1 comment

Comments

@zimonitrome
Copy link

I wanted to improve/modify the torchsort code a little, so I made a copy of the SoftSort class and the soft_sort function.

Running some benchmarks I got the following results:
benchmark_custom
benchmark_custom_cuda

This was worrying: the carbon copy diverges at roughly the same point as in the figure in the README:

I then re-ran the benchmark with the exact same function twice (not even a copy) and got the same results.

That code can be found here:

import sys
from collections import defaultdict
from timeit import timeit

import matplotlib.pyplot as plt
import torch

import torchsort

try:
    import fast_soft_sort.pytorch_ops as fss
except ImportError:
    print("install fast_soft_sort:")
    print("pip install git+https://github.com/google-research/fast-soft-sort")
    sys.exit()


# Sequence lengths swept by the *sequence length* plots: 1, 101, ..., 4901.
N = list(range(1, 5000, 100))
# Batch sizes swept on the CPU (1 .. 256) and on CUDA (1 .. 4096).
B = [1 << i for i in range(9)]
B_CUDA = [1 << i for i in range(13)]
# Number of timeit repetitions averaged into every data point.
SAMPLES = 100
# Seconds per microsecond; dividing a duration in seconds by this yields microseconds.
CONVERT = 1e-6


def time(f, samples=None):
    """Return the mean execution time of ``f()`` in microseconds.

    ``f`` is called ``samples`` times (default: the module-level ``SAMPLES``)
    through ``timeit``; the total is averaged and converted to microseconds
    with ``CONVERT``.

    CUDA kernels are launched asynchronously, so without a synchronisation
    point ``timeit`` only measures how long it takes to *enqueue* the work.
    This is a likely reason why identical calls (e.g. "torchsort" and
    "torchsort_again") can report different timings. When a CUDA context is
    already active, the device is therefore synchronised before timing starts
    and after every call, so that the work itself is measured.
    """
    if samples is None:
        samples = SAMPLES
    # Only synchronise when CUDA has already been initialised. Calling
    # synchronize() otherwise would create a CUDA context in the middle of a
    # CPU-only benchmark.
    sync = torch.cuda.is_available() and torch.cuda.is_initialized()

    def call():
        f()
        if sync:
            torch.cuda.synchronize()

    if sync:
        # Drain previously queued kernels so they are not billed to this measurement.
        torch.cuda.synchronize()
    return timeit(call, number=samples) / samples / CONVERT


def backward(f, x):
    """Run ``f(x)`` and differentiate the summed output with respect to ``x``.

    The gradient is computed only to exercise the backward pass; it is
    discarded and nothing is returned.
    """
    total = f(x).sum()
    torch.autograd.grad(total, x)


def style(name):
    """Return matplotlib line keyword arguments for the series called ``name``.

    "torch.sort" is always solid blue. Every other series is coloured by
    implementation (fast_soft_sort green, "again" reruns red, everything else
    orange), and series whose name mentions "backward" are dashed.
    """
    if name == "torch.sort":
        return {"color": "blue"}

    if "fast_soft_sort" in name:
        color = "green"
    elif "again" in name:
        color = "red"
    else:
        color = "orange"

    dashed = "backward" in name
    return {"color": color, "linestyle": "--" if dashed else "-"}


def batch_size(ax):
    """Plot CPU ``torchsort.soft_sort`` timings (μs) against batch size on ``ax``.

    Inputs have shape ``(batch, 100)`` for every batch size in ``B``. Each
    configuration is timed twice under different labels ("torchsort" and
    "torchsort_again"), forward-only and forward+backward. Drift in the timing
    harness therefore shows up as a gap between curves that should coincide.
    """
    timings = defaultdict(list)
    for batch in B:
        values = torch.randn(batch, 100)
        # timings["torch.sort"].append(time(lambda: torch.sort(values)))
        for label in ("torchsort", "torchsort_again"):
            timings[label].append(time(lambda: torchsort.soft_sort(values)))
        # timings["fast_soft_sort"].append(time(lambda: fss.soft_sort(values)))

        leaf = torch.randn(batch, 100, requires_grad=True)
        for label in ("torchsort (with backward)", "torchsort_again (with backward)"):
            timings[label].append(
                time(lambda: backward(torchsort.soft_sort, leaf))
            )
        # timings["fast_soft_sort (with backward)"].append(
        #     time(lambda: backward(fss.soft_sort, leaf))
        # )

    for label, series in timings.items():
        ax.plot(B, series, label=label, **style(label))
    ax.set_xlabel("Batch Size")
    ax.set_ylim(0, 5000)
    ax.set_ylabel("Execution Time (μs)")
    ax.legend()


def sequence_length(ax):
    """Plot CPU ``torchsort.soft_sort`` timings (μs) against sequence length on ``ax``.

    Inputs have shape ``(1, length)`` for every length in ``N``. Each
    configuration is timed twice under different labels ("torchsort" and
    "torchsort_again"), forward-only and forward+backward, so that timing
    drift is visible as a gap between curves that should coincide.
    """
    timings = defaultdict(list)
    for length in N:
        values = torch.randn(1, length)
        # timings["torch.sort"].append(time(lambda: torch.sort(values)))
        for label in ("torchsort", "torchsort_again"):
            timings[label].append(time(lambda: torchsort.soft_sort(values)))
        # timings["fast_soft_sort"].append(time(lambda: fss.soft_sort(values)))

        leaf = torch.randn(1, length, requires_grad=True)
        for label in ("torchsort (with backward)", "torchsort_again (with backward)"):
            timings[label].append(
                time(lambda: backward(torchsort.soft_sort, leaf))
            )
        # timings["fast_soft_sort (with backward)"].append(
        #     time(lambda: backward(fss.soft_sort, leaf))
        # )

    for label, series in timings.items():
        ax.plot(N, series, label=label, **style(label))
    ax.set_xlabel("Sequence Length")
    ax.set_ylim(0, 1000)
    ax.set_ylabel("Execution Time (μs)")
    ax.legend()


def batch_size_cuda(ax):
    """Plot CUDA ``torchsort.soft_sort`` timings (μs) against batch size on ``ax``.

    Inputs have shape ``(b, 100)`` for every batch size in ``B_CUDA``. Each
    configuration is timed twice under different labels ("torchsort" and
    "torchsort_again"), forward-only and forward+backward, so that timing
    drift is visible as a gap between curves that should coincide.
    """
    data = defaultdict(list)
    for b in B_CUDA:
        # Allocate directly on the GPU. ``torch.randn(...).cuda()`` builds the
        # tensor on the host and copies it over; with requires_grad=True that
        # copy is a non-leaf tensor that keeps the CPU original in the graph.
        x = torch.randn(b, 100, device="cuda")
        # data["torch.sort"].append(time(lambda: torch.sort(x)))
        data["torchsort"].append(time(lambda: torchsort.soft_sort(x)))
        data["torchsort_again"].append(time(lambda: torchsort.soft_sort(x)))
        x = torch.randn(b, 100, device="cuda", requires_grad=True)
        data["torchsort (with backward)"].append(
            time(lambda: backward(torchsort.soft_sort, x))
        )
        data["torchsort_again (with backward)"].append(
            time(lambda: backward(torchsort.soft_sort, x))
        )
    for label, series in data.items():
        ax.plot(B_CUDA, series, label=label, **style(label))
    ax.set_xlabel("Batch Size")
    ax.set_ylabel("Execution Time (μs)")
    ax.legend()


def sequence_length_cuda(ax):
    """Plot CUDA ``torchsort.soft_sort`` timings (μs) against sequence length on ``ax``.

    Inputs have shape ``(1, n)`` for every length in ``N``. Each configuration
    is timed twice under different labels ("torchsort" and "torchsort_again"),
    forward-only and forward+backward, so that timing drift is visible as a
    gap between curves that should coincide.
    """
    data = defaultdict(list)
    for n in N:
        # Allocate directly on the GPU. ``torch.randn(...).cuda()`` builds the
        # tensor on the host and copies it over; with requires_grad=True that
        # copy is a non-leaf tensor that keeps the CPU original in the graph.
        x = torch.randn(1, n, device="cuda")
        # data["torch.sort"].append(time(lambda: torch.sort(x)))
        data["torchsort"].append(time(lambda: torchsort.soft_sort(x)))
        data["torchsort_again"].append(time(lambda: torchsort.soft_sort(x)))
        x = torch.randn(1, n, device="cuda", requires_grad=True)
        data["torchsort (with backward)"].append(
            time(lambda: backward(torchsort.soft_sort, x))
        )
        data["torchsort_again (with backward)"].append(
            time(lambda: backward(torchsort.soft_sort, x))
        )
    for label, series in data.items():
        ax.plot(N, series, label=label, **style(label))
    ax.set_xlabel("Sequence Length")
    ax.set_ylabel("Execution Time (μs)")
    ax.legend()


if __name__ == "__main__":
    import os

    # savefig does not create missing directories; make sure the output folder
    # exists before running the (slow) benchmarks rather than failing at the end.
    os.makedirs("extra", exist_ok=True)

    # jit/warmup: call each implementation once before any timing so that
    # first-call costs are not attributed to the first measurement.
    x = torch.randn(1, 10, requires_grad=True)
    backward(torchsort.soft_sort, x)
    backward(fss.soft_sort, x)

    fig, (ax1, ax2) = plt.subplots(figsize=(10, 4), ncols=2)
    sequence_length(ax1)
    batch_size(ax2)
    fig.suptitle("Torchsort Benchmark: CPU")
    fig.tight_layout()
    fig.savefig("extra/benchmark3.png")

    if torch.cuda.is_available():
        # warmup: allocate the input directly on the GPU (a leaf tensor)
        # instead of copying a CPU tensor over with .cuda().
        x = torch.randn(1, 10, device="cuda", requires_grad=True)
        backward(torchsort.soft_sort, x)
        # CUDA launches are asynchronous; wait for the warmup work to finish
        # so it is not billed to the first timed measurement.
        torch.cuda.synchronize()

        fig, (ax1, ax2) = plt.subplots(figsize=(10, 4), ncols=2)
        sequence_length_cuda(ax1)
        batch_size_cuda(ax2)
        fig.suptitle("Torchsort Benchmark: CUDA")
        fig.tight_layout()
        fig.savefig("extra/benchmark_cuda3.png")

Any idea what this might depend on?

@teddykoker
Copy link
Owner

This is an interesting find! I have tried this now and am seeing a similar phenomenon:
benchmark_cuda3

Running torchsort_2, _3 and _4 (all calling the exact same function repeatedly) gives very similar performance across those runs, but, interestingly, all of them are slower than the initial trial. I suppose a more substantial warmup might remedy this, but I am not quite sure why the performance would be worse. I'll keep looking into it.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants