Skip to content

Commit

Permalink
BallTree WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
tzaeschke committed Aug 11, 2024
1 parent 7984b73 commit 2be2176
Show file tree
Hide file tree
Showing 9 changed files with 65 additions and 428 deletions.
16 changes: 14 additions & 2 deletions src/main/java/org/tinspin/index/PointMap.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.tinspin.index;

import org.tinspin.index.array.PointArray;
import org.tinspin.index.balltree.BallTree;
import org.tinspin.index.covertree.CoverTree;
import org.tinspin.index.kdtree.KDTree;
import org.tinspin.index.phtree.PHTreeP;
Expand Down Expand Up @@ -130,11 +131,22 @@ static <T> PointMap<T> createArray(int dims, int size) {
}

/**
* Create a COverTree.
* Create a BallTree.
*
* @param dims Number of dimensions.
* @param <T> Value type
* @return New PH-Tree
* @return New BallTree
*/
static <T> PointMap<T> createBallTree(int dims) {
return BallTree.create(dims);
}

/**
* Create a CoverTree.
*
* @param dims Number of dimensions.
* @param <T> Value type
* @return New CoverTree
*/
static <T> PointMap<T> createCoverTree(int dims) {
return CoverTree.create(dims);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2016-2017 Tilmann Zaeschke
* Copyright 2016-2024 Tilmann Zaeschke
*
* This file is part of TinSpin.
*
Expand Down Expand Up @@ -28,7 +28,7 @@
*
* @param <T> Value type
*/
public class QIterator0<T> implements PointIterator<T> {
public class BTIterator<T> implements PointIterator<T> {

private class IteratorStack {
private final ArrayList<StackEntry<T>> stack;
Expand Down Expand Up @@ -73,18 +73,20 @@ public void clear() {

private static class StackEntry<T> {
int pos;
BTNode<T>[] subs;
BTNode<T> left;
BTNode<T> right;
ArrayList<PointEntry<T>> vals;
int len;

void set(BTNode<T> node) {
this.pos = 0;
this.vals = node.getEntries();
this.subs = node.getChildNodes();
this.left = node.getLeftChild();
this.right = node.getRightChild();
if (this.vals != null) {
len = this.vals.size();
} else {
len = this.subs.length;
len = 2;
}
}

Expand All @@ -94,7 +96,7 @@ public boolean isLeaf() {
}


QIterator0(BallTree<T> tree, double[] min, double[] max) {
BTIterator(BallTree<T> tree, double[] min, double[] max) {
this.stack = new IteratorStack();
this.tree = tree;
reset(min, max);
Expand All @@ -112,9 +114,8 @@ private void findNext() {
return;
}
} else {
BTNode<T> node = se.subs[pos];
if (node != null &&
BTUtil.overlap(min, max, node.getCenter(), node.getRadius())) {
BTNode<T> node = pos == 0 ? se.left : se.right;
if (BTUtil.overlap(min, max, node.getCenter(), node.getRadius())) {
se = stack.prepareAndPush(node);
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2009-2023 Tilmann Zaeschke. All rights reserved.
* Copyright 2016-2024 Tilmann Zaeschke. All rights reserved.
*
* This file is part of TinSpin.
*
Expand All @@ -25,7 +25,7 @@

import static org.tinspin.index.Index.*;

public class QIteratorKnn<T> implements PointIteratorKnn<T> {
public class BTIteratorKnn<T> implements PointIteratorKnn<T> {

private final BTNode<T> root;
private final PointDistance distFn;
Expand All @@ -38,7 +38,7 @@ public class QIteratorKnn<T> implements PointIteratorKnn<T> {
private double[] center;
private double currentDistance;

QIteratorKnn(BTNode<T> root, int minResults, double[] center, PointDistance distFn, PointFilterKnn<T> filterFn) {
BTIteratorKnn(BTNode<T> root, int minResults, double[] center, PointDistance distFn, PointFilterKnn<T> filterFn) {
this.filterFn = filterFn;
this.distFn = distFn;
this.root = root;
Expand Down Expand Up @@ -124,13 +124,15 @@ private void findNextElement() {
}
}
} else {
for (BTNode<T> subnode : node.getChildNodes()) {
if (subnode != null) {
double dist = distFn.dist(center, subnode.getCenter()) - subnode.getRadius();
if (dist <= maxNodeDist) {
queueN.push(new NodeDistT(dist, subnode));
}
}
BTNode leftChild = node.getLeftChild();
double distL = distFn.dist(center, leftChild.getCenter()) - leftChild.getRadius();
if (distL <= maxNodeDist) {
queueN.push(new NodeDistT(distL, leftChild));
}
BTNode rightChild = node.getLeftChild();
double distR = distFn.dist(center, rightChild.getCenter()) - rightChild.getRadius();
if (distR <= maxNodeDist) {
queueN.push(new NodeDistT(distR, rightChild));
}
}
}
Expand Down
14 changes: 8 additions & 6 deletions src/main/java/org/tinspin/index/balltree/BTNode.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2016-2017 Tilmann Zaeschke
* Copyright 2016-2024 Tilmann Zaeschke
*
* This file is part of TinSpin.
*
Expand Down Expand Up @@ -41,7 +41,6 @@ public class BTNode<T> {
private double radius;
// null indicates that we have sub-node i.o. values
private ArrayList<PointEntry<T>> values;
private BTNode<T>[] subs;
private BTNode<T> left;
private BTNode<T> right;

Expand Down Expand Up @@ -89,7 +88,6 @@ BTNode<T> tryPut(PointEntry<T> e, int maxNodeSize, boolean enforceLeaf) {
//split
ArrayList<PointEntry<T>> vals = values;
values = null;
subs = new BTNode[2];
PointEntry<T> start = vals.get(0);
int dims = start.point().length;
double[][] ordered = BTUtil.orderCoordinates(vals);
Expand Down Expand Up @@ -395,7 +393,7 @@ void checkNode(BallTree.BTStats s, BTNode<T> parent, int depth) {
throw new IllegalStateException();
}
}
if (subs != null) {
if (left != null || right != null) {
throw new IllegalStateException();
}
} else {
Expand All @@ -409,8 +407,12 @@ boolean isLeaf() {
return values != null;
}

BTNode<T>[] getChildNodes() {
return subs;
BTNode<T> getLeftChild() {
return left;
}

BTNode<T> getRightChild() {
return right;
}

// TODO add radii as argument
Expand Down
17 changes: 11 additions & 6 deletions src/main/java/org/tinspin/index/balltree/BTUtil.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2016-2017 Tilmann Zaeschke
* Copyright 2016-2024 Tilmann Zaeschke
*
* This file is part of TinSpin.
*
Expand All @@ -23,8 +23,6 @@
import java.util.ArrayList;
import java.util.Arrays;

import static org.tinspin.index.Index.BoxEntry;

class BTUtil {

static final double EPS_MUL = 1.000000001;
Expand Down Expand Up @@ -81,12 +79,19 @@ public static boolean isPointEqual(double[] p1, double[] p2) {
// }

public static boolean overlap(double[] min, double[] max, double[] center, double radius) {
double[] p = new double[min.length];
for (int d = 0; d < min.length; d++) {
if (max[d] < center[d] - radius || min[d] > center[d] + radius) {
return false;
if (center[d] <= min[d]) {
p[d] = min[d];
} else if (center[d] >= max[d]) {
p[d] = max[d];
} else {
p[d] = center[d];
}

}
return true;
// TODO sqr-dist?
return PointDistance.l2(p, center) <= radius;
}

// public static boolean isRectEnclosed(double[] minEnclosed, double[] maxEnclosed, double[] minOuter, double[] maxOuter) {
Expand Down
18 changes: 8 additions & 10 deletions src/main/java/org/tinspin/index/balltree/BallTree.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2016-2017 Tilmann Zaeschke
* Copyright 2016-2024 Tilmann Zaeschke
*
* This file is part of TinSpin.
*
Expand Down Expand Up @@ -306,7 +306,7 @@ public PointIterator<T> queryExactPoint(double[] point) {
@Override
public PointIterator<T> query(double[] min, double[] max) {
//This does not use min/max but is really very basic.
return new QIterator0<>(this, min, max);
return new BTIterator<>(this, min, max);
//return new QIterator<>(this, min, max);
}

Expand All @@ -325,7 +325,7 @@ public PointEntryKnn<T> query1nn(double[] center) {
*/
@Override
public PointIteratorKnn<T> queryKnn(double[] center, int k, PointDistance dist) {
return new QIteratorKnn<>(root, k, center, dist, (e, d) -> true);
return new BTIteratorKnn<>(root, k, center, dist, (e, d) -> true);
}

/**
Expand All @@ -349,13 +349,11 @@ private void toStringTree(StringBuilderLn sb, BTNode<T> node, int depth, int pos
sb.append(" " + Arrays.toString(node.getCenter()));
sb.appendLn("/" + node.getRadius());
prefix += " ";
if (node.getChildNodes() != null) {
for (int i = 0; i < node.getChildNodes().length; i++) {
BTNode<T> sub = node.getChildNodes()[i];
if (sub != null) {
toStringTree(sb, sub, depth+1, i);
}
}
if (node.getLeftChild() != null) {
toStringTree(sb, node.getLeftChild(), depth+1, 0);
}
if (node.getRightChild() != null) {
toStringTree(sb, node.getRightChild(), depth+1, 1);
}
if (node.getEntries() != null) {
for (int i = 0; i < node.getEntries().size(); i++) {
Expand Down
Loading

0 comments on commit 2be2176

Please sign in to comment.