Optimizing related_post_gen benchmark #838

Open
zigzag312 opened this issue Nov 2, 2023 · 12 comments

Comments

@zigzag312

Hi,

I added daScript to the related_post_gen benchmark.

https://github.com/jinyus/related_post_gen/blob/main/dascript/related.das

This is my first time using daScript, so my implementation is probably very far from optimal.

I couldn't find a way to avoid cloning the posts without using unsafe code, and the cloning seems to be causing out-of-memory (OOM) errors on the test VM. Is there a way to copy a reference or pointer to a struct without using unsafe?

@AntonYudintsev
Collaborator

Why not use the post index instead? It's safe and fast.

@zigzag312
Author

I don't think any implementation does this, as it wouldn't produce the required result: when serialized, allRelatedPosts needs to contain the top 5 related posts themselves, not their indexes. Moving the index lookup to the serialization phase would take it outside the benchmark's processing-time measurement.
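The two positions can be reconciled: compute the top 5 as indices, then resolve those indices into post references before the timer stops, so the lookup stays inside the measured processing time. A minimal C++ sketch of that resolution step, assuming simplified `Post`/`RelatedPosts` types and a hypothetical `resolveTopN` helper (neither is taken from any implementation in the benchmark):

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <string>
#include <utility>
#include <vector>

// Hypothetical minimal types; the benchmark's real fields may differ.
struct Post {
    std::string id;
    std::string title;
    std::vector<std::string> tags;
};

struct RelatedPosts {
    std::string id;
    std::vector<std::string> tags;
    std::vector<const Post*> related;  // points into `posts`; no Post is copied
};

constexpr std::size_t kTopN = 5;

// Resolve per-post top-N index lists into post references. Calling this before
// the timer stops keeps the lookup inside the measured time. It performs only
// posts.size() * kTopN pointer stores. `posts` must outlive the result and
// must not be resized while the result is in use.
std::vector<RelatedPosts> resolveTopN(
        const std::vector<Post>& posts,
        const std::vector<std::array<std::size_t, kTopN>>& topIdx) {
    std::vector<RelatedPosts> out;
    out.reserve(posts.size());
    for (std::size_t i = 0; i < posts.size(); ++i) {
        RelatedPosts rp{posts[i].id, posts[i].tags, {}};
        rp.related.reserve(kTopN);
        for (std::size_t j : topIdx[i]) rp.related.push_back(&posts[j]);
        out.push_back(std::move(rp));
    }
    return out;
}
```

The trade-off of this layout is lifetime: the plain pointers are only valid while `posts` stays alive and unmodified, which is what the heap-allocated `Post?` approach in daScript addresses.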

@borisbat
Collaborator

borisbat commented Nov 3, 2023

options skip_lock_checks

require daslib/json
require fio
require daslib/json_boost
require strings

struct Post
    _id: string
    title: string
    tags: array<string>

struct RelatedPosts
    _id: string
    tags: array<string>
    related: array<Post?>

let topN = 5

[export, unsafe_deref]
def main

    var file = fopen("../posts.json", "r")
    var jsonStr = file |> fread
    var error : string
    var json = read_json(jsonStr, error)
    if error != ""
        print("Error: {error}\n")
        return
    var tposts <- from_JV(json, type<array<Post>>)
    var posts <- [{for tpost in tposts; new [[Post _id = tpost._id, title = tpost.title, tags <- tpost.tags]]}]
    delete tposts

    var start_time = ref_time_ticks()

    let postsCount = length(posts)
    var tagMap : table<string; array<int>>

    for post, i in posts, range(0, postsCount)
        for tag in post.tags
            tagMap[tag] |> push(i)

    var allRelatedPosts : array<RelatedPosts>
    allRelatedPosts |> resize(postsCount)
    var taggedPostCount : array<int>
    taggedPostCount |> reserve(postsCount)
    var top5 : array<tuple<count:int; post_id:int>>
    top5 |> reserve(topN)

    for post, i in posts, range(0, postsCount)
        // reset counts
        taggedPostCount |> clear()
        taggedPostCount |> resize(postsCount)

        // Count the number of tags shared between posts
        for tag in post.tags
            for otherPostIdx in tagMap[tag]
                taggedPostCount[otherPostIdx]++
        taggedPostCount[i] = 0 // Don't count self

        // clear top5
        top5 |> clear()
        top5 |> resize(topN)

        var minTags = 0

        //  custom priority queue to find top N
        for count, j in taggedPostCount, range(0, postsCount)
            if count > minTags
                var upperBound = 3

                while upperBound >= 0 && count > top5[upperBound].count
                    top5[upperBound + 1] = top5[upperBound--]

                top5[upperBound + 1] = [[auto count, j]]
                minTags = top5[topN - 1].count


        // Convert indexes back to Post pointers.
        var topPosts : array<Post?> <- [{for j in range(0,topN); posts[top5[j].post_id]}]
        allRelatedPosts[i] <- [[RelatedPosts _id = posts[i]._id, tags := posts[i].tags, related <- topPosts]]

    var processing_time_ms = float(get_time_usec(start_time)) / 1000.0

    print("Processing time (w/o IO): {format("%.2f", processing_time_ms)} ms\n")

    var jsonResultStr = JV(allRelatedPosts) |> write_json()

    fopen("../related_posts_dascript.json", "w") |> fwrite(jsonResultStr)
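For readers less familiar with daScript, the core of the algorithm above (tag index, shared-tag counting, and the fixed-size top-N insertion) can be sketched in C++ roughly as follows. This is an illustrative port under assumptions: `Post` is trimmed to the fields the algorithm needs, and `Slot` and `topRelated` are hypothetical names. Unlike the daScript version, which leaves unfilled slots at their zero defaults, this sketch marks them with `postId == -1`.

```cpp
#include <array>
#include <cassert>
#include <string>
#include <unordered_map>
#include <vector>

constexpr int kTopN = 5;

// Trimmed-down post: only the fields the algorithm needs.
struct Post {
    std::string id;
    std::vector<std::string> tags;
};

// One top-N entry; postId == -1 marks a slot that was never filled.
struct Slot {
    int count = 0;
    int postId = -1;
};

// For every post, find the kTopN other posts that share the most tags.
std::vector<std::array<Slot, kTopN>> topRelated(const std::vector<Post>& posts) {
    const int n = static_cast<int>(posts.size());

    // tag -> indices of all posts carrying that tag
    std::unordered_map<std::string, std::vector<int>> tagMap;
    for (int i = 0; i < n; ++i)
        for (const auto& tag : posts[i].tags) tagMap[tag].push_back(i);

    std::vector<std::array<Slot, kTopN>> result(n);
    std::vector<int> shared;  // shared-tag count per post, reused across iterations
    for (int i = 0; i < n; ++i) {
        shared.assign(n, 0);  // zero every count, reusing the buffer's memory
        for (const auto& tag : posts[i].tags)
            for (int other : tagMap.at(tag)) ++shared[other];
        shared[i] = 0;  // don't count self

        // Fixed-size insertion sort, kept in descending order of count.
        std::array<Slot, kTopN> top{};
        int minTags = 0;
        for (int j = 0; j < n; ++j) {
            const int count = shared[j];
            if (count <= minTags) continue;  // cannot beat the weakest slot
            int upper = kTopN - 2;           // the daScript code hardcodes 3 here
            while (upper >= 0 && count > top[upper].count) {
                top[upper + 1] = top[upper];  // shift weaker entries down
                --upper;
            }
            top[upper + 1] = Slot{count, j};
            minTags = top[kTopN - 1].count;
        }
        result[i] = top;
    }
    return result;
}
```

Because the top-N array stays sorted, a candidate only has to beat the last slot (`minTags`) to be inserted. The strict `>` comparisons mean ties keep the earlier post, matching the daScript loop.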

@borisbat
Collaborator

borisbat commented Nov 3, 2023

What the code above does: the JSON is loaded into a temporary array, and the posts are then moved to the heap. That way we get pointers to individual posts, and those are safe. A Post? can then be stored just about anywhere.
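A rough C++ analogue of this pattern, using `std::unique_ptr` as a stand-in for daScript's heap-allocated `Post?` (the two have different ownership semantics, so this is an analogy, not a translation): each post gets its own heap allocation, so its address stays stable even when the owning container grows. `toHeap` is a hypothetical helper.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <utility>
#include <vector>

struct Post {
    std::string id;
    std::string title;
    std::vector<std::string> tags;
};

// Move each parsed post into its own heap allocation. The vector stores owning
// pointers, so the Post objects themselves never move when the vector grows.
std::vector<std::unique_ptr<Post>> toHeap(std::vector<Post>&& parsed) {
    std::vector<std::unique_ptr<Post>> heap;
    heap.reserve(parsed.size());
    for (auto& p : parsed)
        heap.push_back(std::make_unique<Post>(std::move(p)));  // moves the fields, no deep copy
    parsed.clear();  // drop the temporary array (compare `delete tposts`)
    return heap;
}
```

Any `Post*` taken from `heap[i].get()` remains valid for as long as `heap` owns that element, regardless of later `push_back` calls.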

I've cleaned up a few obvious things. The biggest is array initialization: in daScript, clear followed by resize leaves the elements initialized to 0, null, empty string, etc. I've also moved topPosts instead of cloning them.
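The same reset idiom exists in C++, shown here only to illustrate why it works (this is standard `std::vector` behavior, not daScript's implementation): `clear()` in practice keeps the allocated capacity, and `resize(n)` value-initializes the new elements, so the counter buffer comes back zeroed without a fresh allocation on every iteration. `resetCounts` is a hypothetical helper.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Reset a per-iteration counter buffer to n zeros.
// clear() destroys the elements but keeps the capacity; resize(n) then
// value-initializes every new element, so each int becomes 0.
void resetCounts(std::vector<int>& counts, std::size_t n) {
    counts.clear();
    counts.resize(n);
}
```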

Both options (skip_lock_checks and unsafe_deref) can be removed. They contribute somewhat to performance, and the data can be reorganized to avoid both.

This is what I get locally:

INTERPRETED:
Processing time (w/o IO): 547.80 ms
AOT:
Processing time (w/o IO): 33.04 ms

@zigzag312
Author

@borisbat That gives a very good performance boost on my machine.

The benchmark rules say no unsafe code blocks, so unsafe_deref and skip_lock_checks probably need to be removed.

How can I compile with AOT? Currently only the interpreted mode is included in the benchmark.

@borisbat
Collaborator

borisbat commented Nov 3, 2023

There is tutorial02aot in the examples, which shows how to set up a CMake project with AOT.

@zigzag312
Author

zigzag312 commented Nov 3, 2023

I'm looking at how run.sh builds the C++ version in the benchmark. It builds a single file directly, like this:

g++ -O3 -std=c++20  -I./include main.cpp -o main

So something like this should work, right?

das -aot related.das related_aot.cpp
g++ -O3 -std=c++20  -I./include related_aot.cpp -o related_aot

@borisbat
Collaborator

borisbat commented Nov 3, 2023

You'll need some sort of main.cpp from daScript on top of related_aot.cpp. AOT generates a custom daScript executable, which is then used to run the script. The aot field of CodeOfPolicies should be set to true.

@zigzag312
Author

I'm not sure I understand what I need to do. Is there any chance you could create a PR for building the AOT version?

@borisbat
Collaborator

borisbat commented Nov 3, 2023

Okie. Once I have a stopping point I'll make a PR.

@AntonYudintsev
Collaborator

How does that “no unsafe” rule apply to C++, then? :)
The C++ code uses direct (unsafe) references/pointers to Posts.

@zigzag312
Author

Good question. I don't really know. To follow the rules, the C++ implementation should probably use shared_ptr for Posts. It's best to raise an issue in the benchmark repo to get clarification on this.
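A minimal sketch of what a `shared_ptr`-based C++ layout might look like, assuming simplified types and a hypothetical `makeRelated` helper; whether this satisfies the benchmark rules is exactly the open question above. Each related entry shares ownership of its post, so nothing dangles even after the original post list is destroyed, at the cost of a reference-count update per copy.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <memory>
#include <string>
#include <vector>

struct Post {
    std::string id;
    std::string title;
    std::vector<std::string> tags;
};

// Each RelatedPosts co-owns its related posts, so there are no raw pointers
// that could outlive the Post objects they refer to.
struct RelatedPosts {
    std::string id;
    std::vector<std::string> tags;
    std::array<std::shared_ptr<const Post>, 5> related;
};

// Build one output record from the top-5 indices computed for post `self`.
RelatedPosts makeRelated(const std::vector<std::shared_ptr<const Post>>& posts,
                         std::size_t self,
                         const std::array<std::size_t, 5>& topIdx) {
    RelatedPosts rp{posts[self]->id, posts[self]->tags, {}};
    for (std::size_t k = 0; k < topIdx.size(); ++k)
        rp.related[k] = posts[topIdx[k]];  // copies the shared_ptr, not the Post
    return rp;
}
```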
