From e56c6d679c9c278ec5b5585f71b0f0f7ab4732e4 Mon Sep 17 00:00:00 2001 From: Daylon Wilkins Date: Thu, 19 Oct 2023 03:32:04 -0700 Subject: [PATCH] Removed some runtime evaluation functionality from parser --- postgres/parser/cache/cache.go | 618 --- postgres/parser/encoding/encoding.go | 2403 +---------- postgres/parser/geo/geodist/geodist.go | 444 -- postgres/parser/geo/geogen/geogen.go | 227 - postgres/parser/geo/geogfn/azimuth.go | 87 - postgres/parser/geo/geogfn/best_projection.go | 154 - postgres/parser/geo/geogfn/covers.go | 319 -- postgres/parser/geo/geogfn/distance.go | 426 -- postgres/parser/geo/geogfn/dwithin.go | 94 - postgres/parser/geo/geogfn/geogfn.go | 36 - postgres/parser/geo/geogfn/geographiclib.go | 37 - postgres/parser/geo/geogfn/intersects.go | 152 - postgres/parser/geo/geogfn/segmentize.go | 113 - .../parser/geo/geogfn/topology_operations.go | 137 - postgres/parser/geo/geogfn/unary_operators.go | 214 - postgres/parser/geo/geoindex/config.pb.go | 1172 ------ postgres/parser/geo/geoindex/config.proto | 79 - postgres/parser/geo/geoindex/geoindex.go | 717 ---- .../parser/geo/geoindex/s2_geography_index.go | 219 - .../parser/geo/geoindex/s2_geometry_index.go | 471 --- .../parser/geo/geomfn/affine_transforms.go | 243 -- postgres/parser/geo/geomfn/azimuth.go | 82 - .../parser/geo/geomfn/binary_predicates.go | 154 - postgres/parser/geo/geomfn/buffer.go | 136 - postgres/parser/geo/geomfn/collections.go | 440 -- postgres/parser/geo/geomfn/coord.go | 71 - postgres/parser/geo/geomfn/de9im.go | 99 - postgres/parser/geo/geomfn/distance.go | 726 ---- postgres/parser/geo/geomfn/envelope.go | 37 - .../parser/geo/geomfn/flip_coordinates.go | 59 - postgres/parser/geo/geomfn/force_layout.go | 138 - postgres/parser/geo/geomfn/geomfn.go | 95 - .../parser/geo/geomfn/linear_reference.go | 80 - postgres/parser/geo/geomfn/linestring.go | 220 - postgres/parser/geo/geomfn/make_geometry.go | 69 - postgres/parser/geo/geomfn/orientation.go | 169 - .../geo/geomfn/remove_repeated_points.go | 134 - postgres/parser/geo/geomfn/reverse.go | 116 - postgres/parser/geo/geomfn/segmentize.go | 97 - .../parser/geo/geomfn/topology_operations.go | 139 - postgres/parser/geo/geomfn/unary_operators.go | 232 -- .../parser/geo/geomfn/unary_predicates.go | 138 - postgres/parser/geo/geomfn/validity_check.go | 88 - .../parser/geo/geosegmentize/geosegmentize.go | 140 - .../parser/interval/btree_based_interval.go | 1156 ------ postgres/parser/interval/bu23.go | 37 - postgres/parser/interval/interval.go | 251 -- .../parser/interval/llrb_based_interval.go | 690 --- postgres/parser/interval/range_group.go | 833 ---- postgres/parser/interval/td234.go | 36 - postgres/parser/ipaddr/ipaddr.go | 17 - postgres/parser/json/encoded.go | 5 +- postgres/parser/json/json.go | 29 +- postgres/parser/kv/emptytxn.go | 18 - postgres/parser/pgdate/zone_cache.go | 5 +- postgres/parser/pgnotice/display_severity.go | 114 - postgres/parser/pgnotice/pgnotice.go | 48 - postgres/parser/privilege/privilege.go | 51 - postgres/parser/protoutil/clone.go | 5 +- postgres/parser/ring/ring_buffer.go | 168 - postgres/parser/sem/tree/aggregate_funcs.go | 58 - postgres/parser/sem/tree/as_of.go | 191 - postgres/parser/sem/tree/casts.go | 862 ---- postgres/parser/sem/tree/constant_eval.go | 82 - postgres/parser/sem/tree/constants.go | 29 - postgres/parser/sem/tree/datum.go | 2668 +----------- postgres/parser/sem/tree/eval.go | 3698 +---------------- postgres/parser/sem/tree/expr.go | 21 - postgres/parser/sem/tree/generators.go | 81 - postgres/parser/sem/tree/indexed_vars.go | 20 - postgres/parser/sem/tree/normalize.go | 1005 ----- postgres/parser/sem/tree/overload.go | 9 - postgres/parser/sem/tree/pgwire_encode.go | 4 +- postgres/parser/sem/tree/regexp_cache.go | 123 - postgres/parser/sem/tree/window_funcs.go | 665 --- postgres/parser/sem/tree/window_funcs_util.go | 243 -- postgres/parser/sessiondata/search_path.go | 10 - postgres/parser/sessiondata/sequence_state.go | 116 - postgres/parser/sessiondata/session_data.go | 418 -- postgres/parser/syncutil/atomic.go | 91 - postgres/parser/syncutil/int_map.go | 413 -- postgres/parser/syncutil/mutex_deadlock.go | 60 - postgres/parser/syncutil/mutex_sync.go | 76 - postgres/parser/syncutil/mutex_sync_race.go | 138 - .../syncutil/singleflight/singleflight.go | 172 - postgres/parser/unique/unique.go | 124 - postgres/parser/uuid/generator.go | 4 +- 87 files changed, 360 insertions(+), 26505 deletions(-) delete mode 100644 postgres/parser/cache/cache.go delete mode 100644 postgres/parser/geo/geodist/geodist.go delete mode 100644 postgres/parser/geo/geogen/geogen.go delete mode 100644 postgres/parser/geo/geogfn/azimuth.go delete mode 100644 postgres/parser/geo/geogfn/best_projection.go delete mode 100644 postgres/parser/geo/geogfn/covers.go delete mode 100644 postgres/parser/geo/geogfn/distance.go delete mode 100644 postgres/parser/geo/geogfn/dwithin.go delete mode 100644 postgres/parser/geo/geogfn/geogfn.go delete mode 100644 postgres/parser/geo/geogfn/geographiclib.go delete mode 100644 postgres/parser/geo/geogfn/intersects.go delete mode 100644 postgres/parser/geo/geogfn/segmentize.go delete mode 100644 postgres/parser/geo/geogfn/topology_operations.go delete mode 100644 postgres/parser/geo/geogfn/unary_operators.go delete mode 100644 postgres/parser/geo/geoindex/config.pb.go delete mode 100644 postgres/parser/geo/geoindex/config.proto delete mode 100644 postgres/parser/geo/geoindex/geoindex.go delete mode 100644 postgres/parser/geo/geoindex/s2_geography_index.go delete mode 100644 postgres/parser/geo/geoindex/s2_geometry_index.go delete mode 100644 postgres/parser/geo/geomfn/affine_transforms.go delete mode 100644 postgres/parser/geo/geomfn/azimuth.go delete mode 100644 postgres/parser/geo/geomfn/binary_predicates.go delete mode 100644 postgres/parser/geo/geomfn/buffer.go delete mode 100644 postgres/parser/geo/geomfn/collections.go delete mode 100644 postgres/parser/geo/geomfn/coord.go delete mode 100644 postgres/parser/geo/geomfn/de9im.go delete mode 100644 postgres/parser/geo/geomfn/distance.go delete mode 100644 postgres/parser/geo/geomfn/envelope.go delete mode 100644 postgres/parser/geo/geomfn/flip_coordinates.go delete mode 100644 postgres/parser/geo/geomfn/force_layout.go delete mode 100644 postgres/parser/geo/geomfn/geomfn.go delete mode 100644 postgres/parser/geo/geomfn/linear_reference.go delete mode 100644 postgres/parser/geo/geomfn/linestring.go delete mode 100644 postgres/parser/geo/geomfn/make_geometry.go delete mode 100644 postgres/parser/geo/geomfn/orientation.go delete mode 100644 postgres/parser/geo/geomfn/remove_repeated_points.go delete mode 100644 postgres/parser/geo/geomfn/reverse.go delete mode 100644 postgres/parser/geo/geomfn/segmentize.go delete mode 100644 postgres/parser/geo/geomfn/topology_operations.go delete mode 100644 postgres/parser/geo/geomfn/unary_operators.go delete mode 100644 postgres/parser/geo/geomfn/unary_predicates.go delete mode 100644 postgres/parser/geo/geomfn/validity_check.go delete mode 100644 postgres/parser/geo/geosegmentize/geosegmentize.go delete mode 100644 postgres/parser/interval/btree_based_interval.go delete mode 100644 postgres/parser/interval/bu23.go delete mode 100644 postgres/parser/interval/interval.go delete mode 100644 postgres/parser/interval/llrb_based_interval.go delete mode 100644 postgres/parser/interval/range_group.go delete mode 100644 postgres/parser/interval/td234.go delete mode 100644 postgres/parser/kv/emptytxn.go delete mode 100644 postgres/parser/pgnotice/display_severity.go delete mode 100644 postgres/parser/pgnotice/pgnotice.go delete mode 100644 postgres/parser/ring/ring_buffer.go delete mode 100644 postgres/parser/sem/tree/aggregate_funcs.go delete mode 100644 postgres/parser/sem/tree/as_of.go delete mode 100644 postgres/parser/sem/tree/constant_eval.go delete mode 100644 postgres/parser/sem/tree/constants.go delete mode 100644 postgres/parser/sem/tree/generators.go delete mode 100644 postgres/parser/sem/tree/normalize.go delete mode 100644 postgres/parser/sem/tree/regexp_cache.go delete mode 100644 postgres/parser/sem/tree/window_funcs.go delete mode 100644 postgres/parser/sem/tree/window_funcs_util.go delete mode 100644 postgres/parser/sessiondata/sequence_state.go delete mode 100644 postgres/parser/sessiondata/session_data.go delete mode 100644 postgres/parser/syncutil/atomic.go delete mode 100644 postgres/parser/syncutil/int_map.go delete mode 100644 postgres/parser/syncutil/mutex_deadlock.go delete mode 100644 postgres/parser/syncutil/mutex_sync.go delete mode 100644 postgres/parser/syncutil/mutex_sync_race.go delete mode 100644 postgres/parser/syncutil/singleflight/singleflight.go delete mode 100644 postgres/parser/unique/unique.go diff --git a/postgres/parser/cache/cache.go b/postgres/parser/cache/cache.go deleted file mode 100644 index f87dfeb1b6..0000000000 --- a/postgres/parser/cache/cache.go +++ /dev/null @@ -1,618 +0,0 @@ -// Copyright 2023 Dolthub, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Copyright 2014 The Cockroach Authors. -// -// Use of this software is governed by the Business Source License -// included in the file licenses/BSL.txt. -// -// As of the Change Date specified in that file, in accordance with -// the Business Source License, use of this software will be governed -// by the Apache License, Version 2.0, included in the file -// licenses/APL.txt. -// -// This code is based on: https://github.com/golang/groupcache/ - -package cache - -import ( - "bytes" - "fmt" - "sync/atomic" - - "github.com/biogo/store/llrb" - - "github.com/dolthub/doltgresql/postgres/parser/interval" -) - -// EvictionPolicy is the cache eviction policy enum. -type EvictionPolicy int - -// Constants describing LRU and FIFO, and None cache eviction policies -// respectively. -const ( - CacheLRU EvictionPolicy = iota // Least recently used - CacheFIFO // First in, first out - CacheNone // No evictions; don't maintain ordering list -) - -// A Config specifies the eviction policy, eviction -// trigger callback, and eviction listener callback. -type Config struct { - // Policy is one of the consts listed for EvictionPolicy. - Policy EvictionPolicy - - // ShouldEvict is a callback function executed each time a new entry - // is added to the cache. It supplies cache size, and potential - // evictee's key and value. The function should return true if the - // entry may be evicted; false otherwise. For example, to support a - // maximum size for the cache, use a method like: - // - // func(size int, key Key, value interface{}) { return size > maxSize } - // - // To support max TTL in the cache, use something like: - // - // func(size int, key Key, value interface{}) { - // return timeutil.Now().UnixNano() - value.(int64) > maxTTLNanos - // } - ShouldEvict func(size int, key, value interface{}) bool - - // OnEvicted optionally specifies a callback function to be - // executed when an entry is purged from the cache. - OnEvicted func(key, value interface{}) - - // OnEvictedEntry optionally specifies a callback function to - // be executed when an entry is purged from the cache. - OnEvictedEntry func(entry *Entry) -} - -// Entry holds the key and value and a pointer to the linked list -// which defines the eviction ordering. -type Entry struct { - Key, Value interface{} - next, prev *Entry -} - -func (e Entry) String() string { - return fmt.Sprintf("%s", e.Key) -} - -// Compare implements the llrb.Comparable interface for cache entries. -// This facility is used by the OrderedCache, and crucially requires -// that keys used with that cache implement llrb.Comparable. -func (e *Entry) Compare(b llrb.Comparable) int { - return e.Key.(llrb.Comparable).Compare(b.(*Entry).Key.(llrb.Comparable)) -} - -// The following methods implement the interval.Interface for entry by casting -// the entry key to an interval key and calling the appropriate accessers. - -// ID implements interval.Interface -func (e *Entry) ID() uintptr { - return e.Key.(*IntervalKey).id -} - -// Range implements interval.Interface -func (e *Entry) Range() interval.Range { - return e.Key.(*IntervalKey).Range -} - -// entryList is a double-linked circular list of *Entry elements. The code is -// derived from the stdlib container/list but customized to Entry in order to -// avoid a separate allocation for every element. -type entryList struct { - root Entry -} - -func (l *entryList) init() { - l.root.next = &l.root - l.root.prev = &l.root -} - -func (l *entryList) back() *Entry { - return l.root.prev -} - -func (l *entryList) insertAfter(e, at *Entry) { - n := at.next - at.next = e - e.prev = at - e.next = n - n.prev = e -} - -func (l *entryList) insertBefore(e, mark *Entry) { - l.insertAfter(e, mark.prev) -} - -func (l *entryList) remove(e *Entry) *Entry { - if e == &l.root { - panic("cannot remove root list node") - } - // TODO(peter): Revert this protection against removing a non-existent entry - // from the list when the cause of - // https://github.com/cockroachdb/cockroach/issues/6190 is determined. Should - // be replaced with an explicit panic instead of the implicit one of a - // nil-pointer dereference. - if e.next != nil { - e.prev.next = e.next - e.next.prev = e.prev - e.next = nil // avoid memory leaks - e.prev = nil // avoid memory leaks - } - return e -} - -func (l *entryList) pushFront(e *Entry) { - l.insertAfter(e, &l.root) -} - -func (l *entryList) moveToFront(e *Entry) { - if l.root.next == e { - return - } - l.insertAfter(l.remove(e), &l.root) -} - -// cacheStore is an interface for the backing store used for the cache. -type cacheStore interface { - // init initializes or clears all entries. - init() - // get returns the entry by key. - get(key interface{}) *Entry - // add stores an entry. - add(e *Entry) - // del removes an entry. - del(e *Entry) - // len is number of items in store. - length() int -} - -// baseCache contains the config, cacheStore interface, and the linked -// list for eviction order. -type baseCache struct { - Config - store cacheStore - ll entryList -} - -func newBaseCache(config Config) baseCache { - return baseCache{ - Config: config, - } -} - -// init initializes the baseCache with the provided cacheStore. It must be -// called with a non-nil cacheStore before use of the cache. -func (bc *baseCache) init(store cacheStore) { - bc.ll.init() - bc.store = store - bc.store.init() -} - -// Add adds a value to the cache. -func (bc *baseCache) Add(key, value interface{}) { - bc.add(key, value, nil, nil) -} - -// AddEntry adds a value to the cache. It provides an alternative interface to -// Add which the caller can use to reduce allocations by bundling the Entry -// structure with the key and value to be stored. -func (bc *baseCache) AddEntry(entry *Entry) { - bc.add(entry.Key, entry.Value, entry, nil) -} - -// AddEntryAfter adds a value to the cache, making sure that it is placed after -// the second entry in the eviction queue. It provides an alternative interface to -// Add which the caller can use to reduce allocations by bundling the Entry -// structure with the key and value to be stored. -func (bc *baseCache) AddEntryAfter(entry, after *Entry) { - bc.add(entry.Key, entry.Value, entry, after) -} - -// MoveToEnd moves the entry to the end of the eviction queue. -func (bc *baseCache) MoveToEnd(entry *Entry) { - bc.ll.moveToFront(entry) -} - -func (bc *baseCache) add(key, value interface{}, entry, after *Entry) { - if e := bc.store.get(key); e != nil { - bc.access(e) - e.Value = value - return - } - e := entry - if e == nil { - e = &Entry{Key: key, Value: value} - } - if after != nil { - bc.ll.insertBefore(e, after) - } else { - bc.ll.pushFront(e) - } - bc.store.add(e) - // Evict as many elements as we can. - for bc.evict() { - } -} - -// Get looks up a key's value from the cache. -func (bc *baseCache) Get(key interface{}) (value interface{}, ok bool) { - if e := bc.store.get(key); e != nil { - bc.access(e) - return e.Value, true - } - return -} - -// Del removes the provided key from the cache. -func (bc *baseCache) Del(key interface{}) { - e := bc.store.get(key) - bc.DelEntry(e) -} - -// DelEntry removes the provided entry from the cache. -func (bc *baseCache) DelEntry(entry *Entry) { - if entry != nil { - bc.removeElement(entry) - } -} - -// Clear clears all entries from the cache. -func (bc *baseCache) Clear() { - if bc.OnEvicted != nil || bc.OnEvictedEntry != nil { - for e := bc.ll.back(); e != &bc.ll.root; e = e.prev { - if bc.OnEvicted != nil { - bc.OnEvicted(e.Key, e.Value) - } - if bc.OnEvictedEntry != nil { - bc.OnEvictedEntry(e) - } - } - } - bc.ll.init() - bc.store.init() -} - -// Len returns the number of items in the cache. -func (bc *baseCache) Len() int { - return bc.store.length() -} - -func (bc *baseCache) access(e *Entry) { - if bc.Policy == CacheLRU { - bc.ll.moveToFront(e) - } -} - -func (bc *baseCache) removeElement(e *Entry) { - bc.ll.remove(e) - bc.store.del(e) - if bc.OnEvicted != nil { - bc.OnEvicted(e.Key, e.Value) - } - if bc.OnEvictedEntry != nil { - bc.OnEvictedEntry(e) - } -} - -// evict removes the oldest item from the cache for FIFO and -// the least recently used item for LRU. Returns true if an -// entry was evicted, false otherwise. -func (bc *baseCache) evict() bool { - if bc.ShouldEvict == nil || bc.Policy == CacheNone { - return false - } - l := bc.store.length() - if l > 0 { - e := bc.ll.back() - if bc.ShouldEvict(l, e.Key, e.Value) { - bc.removeElement(e) - return true - } - } - return false -} - -// UnorderedCache is a cache which supports custom eviction triggers and two -// eviction policies: LRU and FIFO. A listener pattern is available -// for eviction events. This cache uses a hashmap for storing elements, -// making it the most performant. Only exact lookups are supported. -// -// UnorderedCache requires that keys are comparable, according to the go -// specification (http://golang.org/ref/spec#Comparison_operators). -// -// UnorderedCache is not safe for concurrent access. -type UnorderedCache struct { - baseCache - hmap map[interface{}]interface{} -} - -// NewUnorderedCache creates a new UnorderedCache backed by a hash map. -func NewUnorderedCache(config Config) *UnorderedCache { - mc := &UnorderedCache{ - baseCache: newBaseCache(config), - } - mc.baseCache.init(mc) - return mc -} - -// Implementation of cacheStore interface. -func (mc *UnorderedCache) init() { - mc.hmap = make(map[interface{}]interface{}) -} -func (mc *UnorderedCache) get(key interface{}) *Entry { - if e, ok := mc.hmap[key].(*Entry); ok { - return e - } - return nil -} -func (mc *UnorderedCache) add(e *Entry) { - mc.hmap[e.Key] = e -} -func (mc *UnorderedCache) del(e *Entry) { - delete(mc.hmap, e.Key) -} -func (mc *UnorderedCache) length() int { - return len(mc.hmap) -} - -// OrderedCache is a cache which supports binary searches using Ceil -// and Floor methods. It is backed by a left-leaning red black tree. -// See comments in UnorderedCache for more details on cache functionality. -// -// OrderedCache requires that keys implement llrb.Comparable. -// -// OrderedCache is not safe for concurrent access. -type OrderedCache struct { - baseCache - llrb llrb.Tree -} - -// NewOrderedCache creates a new Cache backed by a left-leaning red -// black binary tree which supports binary searches via the Ceil() and -// Floor() methods. See NewUnorderedCache() for details on parameters. -func NewOrderedCache(config Config) *OrderedCache { - oc := &OrderedCache{ - baseCache: newBaseCache(config), - } - oc.baseCache.init(oc) - return oc -} - -// Implementation of cacheStore interface. -func (oc *OrderedCache) init() { - oc.llrb = llrb.Tree{} -} -func (oc *OrderedCache) get(key interface{}) *Entry { - if e, ok := oc.llrb.Get(&Entry{Key: key}).(*Entry); ok { - return e - } - return nil -} -func (oc *OrderedCache) add(e *Entry) { - oc.llrb.Insert(e) -} -func (oc *OrderedCache) del(e *Entry) { - oc.llrb.Delete(e) -} -func (oc *OrderedCache) length() int { - return oc.llrb.Len() -} - -// CeilEntry returns the smallest cache entry greater than or equal to key. -func (oc *OrderedCache) CeilEntry(key interface{}) (*Entry, bool) { - if e, ok := oc.llrb.Ceil(&Entry{Key: key}).(*Entry); ok { - return e, true - } - return nil, false -} - -// Ceil returns the smallest key-value pair greater than or equal to key. -func (oc *OrderedCache) Ceil(key interface{}) (interface{}, interface{}, bool) { - if e, ok := oc.CeilEntry(key); ok { - return e.Key, e.Value, true - } - return nil, nil, false -} - -// FloorEntry returns the greatest cache entry less than or equal to key. -func (oc *OrderedCache) FloorEntry(key interface{}) (*Entry, bool) { - if e, ok := oc.llrb.Floor(&Entry{Key: key}).(*Entry); ok { - return e, true - } - return nil, false -} - -// Floor returns the greatest key-value pair less than or equal to key. -func (oc *OrderedCache) Floor(key interface{}) (interface{}, interface{}, bool) { - if e, ok := oc.FloorEntry(key); ok { - return e.Key, e.Value, true - } - return nil, nil, false -} - -// DoEntry invokes f on all cache entries in the cache. f returns a boolean -// indicating the traversal is done. If f returns true, the DoEntry loop will -// exit; false, it will continue. DoEntry returns whether the iteration exited -// early. -func (oc *OrderedCache) DoEntry(f func(e *Entry) bool) bool { - return oc.llrb.Do(func(e llrb.Comparable) bool { - return f(e.(*Entry)) - }) -} - -// Do invokes f on all key-value pairs in the cache. f returns a boolean -// indicating the traversal is done. If f returns true, the Do loop will exit; -// false, it will continue. Do returns whether the iteration exited early. -func (oc *OrderedCache) Do(f func(k, v interface{}) bool) bool { - return oc.DoEntry(func(e *Entry) bool { - return f(e.Key, e.Value) - }) -} - -// DoRangeEntry invokes f on all cache entries in the range of from -> to. f -// returns a boolean indicating the traversal is done. If f returns true, the -// DoRangeEntry loop will exit; false, it will continue. DoRangeEntry returns -// whether the iteration exited early. -func (oc *OrderedCache) DoRangeEntry(f func(e *Entry) bool, from, to interface{}) bool { - return oc.llrb.DoRange(func(e llrb.Comparable) bool { - return f(e.(*Entry)) - }, &Entry{Key: from}, &Entry{Key: to}) -} - -// DoRangeReverseEntry invokes f on all cache entries in the range (to, from]. from -// should be higher than to. -// f returns a boolean indicating the traversal is done. If f returns true, the -// DoRangeReverseEntry loop will exit; false, it will continue. -// DoRangeReverseEntry returns whether the iteration exited early. -func (oc *OrderedCache) DoRangeReverseEntry(f func(e *Entry) bool, from, to interface{}) bool { - return oc.llrb.DoRangeReverse(func(e llrb.Comparable) bool { - return f(e.(*Entry)) - }, &Entry{Key: from}, &Entry{Key: to}) -} - -// DoRange invokes f on all key-value pairs in the range of from -> to. f -// returns a boolean indicating the traversal is done. If f returns true, the -// DoRange loop will exit; false, it will continue. DoRange returns whether the -// iteration exited early. -func (oc *OrderedCache) DoRange(f func(k, v interface{}) bool, from, to interface{}) bool { - return oc.DoRangeEntry(func(e *Entry) bool { - return f(e.Key, e.Value) - }, from, to) -} - -// IntervalCache is a cache which supports querying of intervals which -// match a key or range of keys. It is backed by an interval tree. See -// comments in UnorderedCache for more details on cache functionality. -// -// Note that the IntervalCache allow multiple identical segments, as -// specified by start and end keys. -// -// Keys supplied to the IntervalCache's Get, Add & Del methods must be -// constructed from IntervalCache.NewKey(). -// -// IntervalCache is not safe for concurrent access. -type IntervalCache struct { - baseCache - tree interval.Tree - - // The fields below are used to avoid allocations during get, del and - // GetOverlaps. - getID uintptr - getEntry *Entry - overlapKey IntervalKey - overlaps []*Entry -} - -// IntervalKey provides uniqueness as well as key interval. -type IntervalKey struct { - interval.Range - id uintptr -} - -var intervalAlloc int64 - -func (ik IntervalKey) String() string { - return fmt.Sprintf("%d: %q-%q", ik.id, ik.Start, ik.End) -} - -// NewIntervalCache creates a new Cache backed by an interval tree. -// See NewCache() for details on parameters. -func NewIntervalCache(config Config) *IntervalCache { - ic := &IntervalCache{ - baseCache: newBaseCache(config), - } - ic.baseCache.init(ic) - return ic -} - -// NewKey creates a new interval key defined by start and end values. -func (ic *IntervalCache) NewKey(start, end []byte) *IntervalKey { - k := ic.MakeKey(start, end) - return &k -} - -// MakeKey creates a new interval key defined by start and end values. -func (ic *IntervalCache) MakeKey(start, end []byte) IntervalKey { - if bytes.Compare(start, end) >= 0 { - panic(fmt.Sprintf("start key greater than or equal to end key %q >= %q", start, end)) - } - return IntervalKey{ - Range: interval.Range{ - Start: interval.Comparable(start), - End: interval.Comparable(end), - }, - id: uintptr(atomic.AddInt64(&intervalAlloc, 1)), - } -} - -// Implementation of cacheStore interface. -func (ic *IntervalCache) init() { - ic.tree = interval.NewTree(interval.ExclusiveOverlapper) -} - -func (ic *IntervalCache) get(key interface{}) *Entry { - ik := key.(*IntervalKey) - ic.getID = ik.id - ic.tree.DoMatching(ic.doGet, ik.Range) - e := ic.getEntry - ic.getEntry = nil - return e -} - -func (ic *IntervalCache) doGet(i interval.Interface) bool { - e := i.(*Entry) - if e.ID() == ic.getID { - ic.getEntry = e - return true - } - return false -} - -func (ic *IntervalCache) add(e *Entry) { - if err := ic.tree.Insert(e, false); err != nil { - fmt.Printf("Cache Error: %s\n", err.Error()) - } -} - -func (ic *IntervalCache) del(e *Entry) { - if err := ic.tree.Delete(e, false); err != nil { - fmt.Printf("Cache Error: %s\n", err.Error()) - } -} - -func (ic *IntervalCache) length() int { - return ic.tree.Len() -} - -// GetOverlaps returns a slice of values which overlap the specified -// interval. The slice is only valid until the next call to GetOverlaps. -func (ic *IntervalCache) GetOverlaps(start, end []byte) []*Entry { - ic.overlapKey.Range = interval.Range{ - Start: interval.Comparable(start), - End: interval.Comparable(end), - } - ic.tree.DoMatching(ic.doOverlaps, ic.overlapKey.Range) - overlaps := ic.overlaps - ic.overlaps = ic.overlaps[:0] - return overlaps -} - -func (ic *IntervalCache) doOverlaps(i interval.Interface) bool { - e := i.(*Entry) - ic.access(e) // maintain cache eviction ordering - ic.overlaps = append(ic.overlaps, e) - return false -} diff --git a/postgres/parser/encoding/encoding.go b/postgres/parser/encoding/encoding.go index 3ec340e9f4..30c0067cc4 100644 --- a/postgres/parser/encoding/encoding.go +++ b/postgres/parser/encoding/encoding.go @@ -27,28 +27,13 @@ package encoding import ( "bytes" "encoding/binary" - "encoding/hex" "fmt" - "math" "reflect" - "strconv" - "strings" - "time" - "unicode" - "unicode/utf8" "unsafe" "github.com/cockroachdb/apd/v2" "github.com/cockroachdb/errors" - "github.com/dolthub/doltgresql/postgres/parser/duration" - "github.com/dolthub/doltgresql/postgres/parser/geo" - "github.com/dolthub/doltgresql/postgres/parser/geo/geopb" - "github.com/dolthub/doltgresql/postgres/parser/ipaddr" - "github.com/dolthub/doltgresql/postgres/parser/protoutil" - "github.com/dolthub/doltgresql/postgres/parser/timeofday" - "github.com/dolthub/doltgresql/postgres/parser/timetz" - "github.com/dolthub/doltgresql/postgres/parser/utils" "github.com/dolthub/doltgresql/postgres/parser/uuid" ) @@ -115,14 +100,6 @@ const ( arrayKeyTerminator byte = 0x00 arrayKeyDescendingTerminator byte = 0xFF - // We use different null encodings for nulls within key arrays. - // Doing this allows for the terminator to be less/greater than - // the null value within arrays. These byte values overlap with - // encodedNotNull, encodedNotNullDesc, and interleavedSentinel, - // but they can only exist within an encoded array key. Because - // of the context, they cannot be ambiguous with these other bytes. - ascendingNullWithinArrayKey byte = 0x01 - descendingNullWithinArrayKey byte = 0xFE // IntMin is chosen such that the range of int tags does not overlap the // ascii character set that is frequently used in testing. @@ -137,29 +114,7 @@ const ( // This value is not actually ever present in a stored key, but // it's used in keys used as span boundaries for index scans. encodedNotNullDesc = 0xfe - // interleavedSentinel uses the same byte as encodedNotNullDesc. - // It is used in the key encoding of interleaved index keys in order - // to coerce the key to sort after its respective parent and ancestors' - // index keys. - // The byte for NotNullDesc was chosen over NullDesc since NotNullDesc - // is never used in actual encoded keys. - // This allowed the key pretty printer for interleaved keys to work - // without table descriptors. - interleavedSentinel = 0xfe - encodedNullDesc = 0xff - - // offsetSecsToMicros is a constant that allows conversion from seconds - // to microseconds for offsetSecs type calculations (e.g. for TimeTZ). - offsetSecsToMicros = 1000000 -) - -const ( - // EncodedDurationMaxLen is the largest number of bytes used when encoding a - // Duration. - EncodedDurationMaxLen = 1 + 3*binary.MaxVarintLen64 // 3 varints are encoded. - // EncodedTimeTZMaxLen is the largest number of bytes used when encoding a - // TimeTZ. - EncodedTimeTZMaxLen = 1 + binary.MaxVarintLen64 + binary.MaxVarintLen32 + encodedNullDesc = 0xff ) // Direction for ordering results. @@ -195,12 +150,6 @@ func PutUint32Ascending(b []byte, v uint32, idx int) []byte { return b } -// EncodeUint32Descending encodes the uint32 value so that it sorts in -// reverse order, from largest to smallest. -func EncodeUint32Descending(b []byte, v uint32) []byte { - return EncodeUint32Ascending(b, ^v) -} - // DecodeUint32Ascending decodes a uint32 from the input buffer, treating // the input as a big-endian 4 byte uint32 representation. The remainder // of the input buffer and the decoded uint32 are returned. @@ -212,13 +161,6 @@ func DecodeUint32Ascending(b []byte) ([]byte, uint32, error) { return b[4:], v, nil } -// DecodeUint32Descending decodes a uint32 value which was encoded -// using EncodeUint32Descending. -func DecodeUint32Descending(b []byte) ([]byte, uint32, error) { - leftover, v, err := DecodeUint32Ascending(b) - return leftover, ^v, err -} - const uint64AscendingEncodedLength = 8 // EncodeUint64Ascending encodes the uint64 value using a big-endian 8 byte @@ -230,12 +172,6 @@ func EncodeUint64Ascending(b []byte, v uint64) []byte { byte(v>>24), byte(v>>16), byte(v>>8), byte(v)) } -// EncodeUint64Descending encodes the uint64 value so that it sorts in -// reverse order, from largest to smallest. -func EncodeUint64Descending(b []byte, v uint64) []byte { - return EncodeUint64Ascending(b, ^v) -} - // DecodeUint64Ascending decodes a uint64 from the input buffer, treating // the input as a big-endian 8 byte uint64 representation. The remainder // of the input buffer and the decoded uint64 are returned. @@ -247,13 +183,6 @@ func DecodeUint64Ascending(b []byte) ([]byte, uint64, error) { return b[8:], v, nil } -// DecodeUint64Descending decodes a uint64 value which was encoded -// using EncodeUint64Descending. -func DecodeUint64Descending(b []byte) ([]byte, uint64, error) { - leftover, v, err := DecodeUint64Ascending(b) - return leftover, ^v, err -} - // MaxVarintLen is the maximum length of a value encoded using any of: // - EncodeVarintAscending // - EncodeVarintDescending @@ -261,46 +190,6 @@ func DecodeUint64Descending(b []byte) ([]byte, uint64, error) { // - EncodeUvarintDescending const MaxVarintLen = 9 -// EncodeVarintAscending encodes the int64 value using a variable length -// (length-prefixed) representation. The length is encoded as a single -// byte. If the value to be encoded is negative the length is encoded -// as 8-numBytes. If the value is positive it is encoded as -// 8+numBytes. The encoded bytes are appended to the supplied buffer -// and the final buffer is returned. -func EncodeVarintAscending(b []byte, v int64) []byte { - if v < 0 { - switch { - case v >= -0xff: - return append(b, IntMin+7, byte(v)) - case v >= -0xffff: - return append(b, IntMin+6, byte(v>>8), byte(v)) - case v >= -0xffffff: - return append(b, IntMin+5, byte(v>>16), byte(v>>8), byte(v)) - case v >= -0xffffffff: - return append(b, IntMin+4, byte(v>>24), byte(v>>16), byte(v>>8), byte(v)) - case v >= -0xffffffffff: - return append(b, IntMin+3, byte(v>>32), byte(v>>24), byte(v>>16), byte(v>>8), - byte(v)) - case v >= -0xffffffffffff: - return append(b, IntMin+2, byte(v>>40), byte(v>>32), byte(v>>24), byte(v>>16), - byte(v>>8), byte(v)) - case v >= -0xffffffffffffff: - return append(b, IntMin+1, byte(v>>48), byte(v>>40), byte(v>>32), byte(v>>24), - byte(v>>16), byte(v>>8), byte(v)) - default: - return append(b, IntMin, byte(v>>56), byte(v>>48), byte(v>>40), byte(v>>32), - byte(v>>24), byte(v>>16), byte(v>>8), byte(v)) - } - } - return EncodeUvarintAscending(b, uint64(v)) -} - -// EncodeVarintDescending encodes the int64 value so that it sorts in reverse -// order, from largest to smallest. -func EncodeVarintDescending(b []byte, v int64) []byte { - return EncodeVarintAscending(b, ^v) -} - // getVarintLen returns the encoded length of an encoded varint. Assumes the // slice has at least one byte. func getVarintLen(b []byte) (int, error) { @@ -323,45 +212,6 @@ func getVarintLen(b []byte) (int, error) { return length, nil } -// DecodeVarintAscending decodes a value encoded by EncodeVaringAscending. -func DecodeVarintAscending(b []byte) ([]byte, int64, error) { - if len(b) == 0 { - return nil, 0, errors.Errorf("insufficient bytes to decode uvarint value") - } - length := int(b[0]) - intZero - if length < 0 { - length = -length - remB := b[1:] - if len(remB) < length { - return nil, 0, errors.Errorf("insufficient bytes to decode uvarint value: %q", remB) - } - var v int64 - // Use the ones-complement of each encoded byte in order to build - // up a positive number, then take the ones-complement again to - // arrive at our negative value. - for _, t := range remB[:length] { - v = (v << 8) | int64(^t) - } - return remB[length:], ^v, nil - } - - remB, v, err := DecodeUvarintAscending(b) - if err != nil { - return remB, 0, err - } - if v > math.MaxInt64 { - return nil, 0, errors.Errorf("varint %d overflows int64", v) - } - return remB, int64(v), nil -} - -// DecodeVarintDescending decodes a uint64 value which was encoded -// using EncodeVarintDescending. -func DecodeVarintDescending(b []byte) ([]byte, int64, error) { - leftover, v, err := DecodeVarintAscending(b) - return leftover, ^v, err -} - // EncodeUvarintAscending encodes the uint64 value using a variable length // (length-prefixed) representation. The length is encoded as a single // byte indicating the number of encoded bytes (-8) to follow. See @@ -540,15 +390,6 @@ var ( descendingGeoEscapes = escapes{^escape, ^escapedTerm, ^escaped00, ^escapedFF, geoDescMarker} ) -// EncodeBytesAscending encodes the []byte value using an escape-based -// encoding. The encoded value is terminated with the sequence -// "\x00\x01" which is guaranteed to not occur elsewhere in the -// encoded value. The encoded bytes are append to the supplied buffer -// and the resulting buffer is returned. -func EncodeBytesAscending(b []byte, data []byte) []byte { - return encodeBytesAscendingWithTerminatorAndPrefix(b, data, ascendingBytesEscapes.escapedTerm, bytesMarker) -} - // encodeBytesAscendingWithTerminatorAndPrefix encodes the []byte value using an escape-based // encoding. The encoded value is terminated with the sequence // "\x00\terminator". The encoded bytes are append to the supplied buffer @@ -588,77 +429,6 @@ func encodeBytesAscendingWithoutTerminatorOrPrefix(b []byte, data []byte) []byte return append(b, data...) } -// EncodeBytesDescending encodes the []byte value using an -// escape-based encoding and then inverts (ones complement) the result -// so that it sorts in reverse order, from larger to smaller -// lexicographically. -func EncodeBytesDescending(b []byte, data []byte) []byte { - n := len(b) - b = EncodeBytesAscending(b, data) - b[n] = bytesDescMarker - onesComplement(b[n+1:]) - return b -} - -// DecodeBytesAscending decodes a []byte value from the input buffer -// which was encoded using EncodeBytesAscending. The decoded bytes -// are appended to r. The remainder of the input buffer and the -// decoded []byte are returned. -func DecodeBytesAscending(b []byte, r []byte) ([]byte, []byte, error) { - return decodeBytesInternal(b, r, ascendingBytesEscapes, true /* expectMarker */) -} - -// DecodeBytesDescending decodes a []byte value from the input buffer -// which was encoded using EncodeBytesDescending. The decoded bytes -// are appended to r. The remainder of the input buffer and the -// decoded []byte are returned. -func DecodeBytesDescending(b []byte, r []byte) ([]byte, []byte, error) { - // Always pass an `r` to make sure we never get back a sub-slice of `b`, - // since we're going to modify the contents of the slice. - if r == nil { - r = []byte{} - } - b, r, err := decodeBytesInternal(b, r, descendingBytesEscapes, true /* expectMarker */) - onesComplement(r) - return b, r, err -} - -func decodeBytesInternal(b []byte, r []byte, e escapes, expectMarker bool) ([]byte, []byte, error) { - if expectMarker { - if len(b) == 0 || b[0] != e.marker { - return nil, nil, errors.Errorf("did not find marker %#x in buffer %#x", e.marker, b) - } - b = b[1:] - } - - for { - i := bytes.IndexByte(b, e.escape) - if i == -1 { - return nil, nil, errors.Errorf("did not find terminator %#x in buffer %#x", e.escape, b) - } - if i+1 >= len(b) { - return nil, nil, errors.Errorf("malformed escape in buffer %#x", b) - } - v := b[i+1] - if v == e.escapedTerm { - if r == nil { - r = b[:i] - } else { - r = append(r, b[:i]...) - } - return b[i+2:], r, nil - } - - if v != e.escaped00 { - return nil, nil, errors.Errorf("unknown escape sequence: %#x %#x", e.escape, v) - } - - r = append(r, b[:i]...) - r = append(r, e.escapedFF) - b = b[i+2:] - } -} - // getBytesLength finds the length of a bytes encoding. func getBytesLength(b []byte, e escapes) (int, error) { // Skip the tag. @@ -678,46 +448,6 @@ func getBytesLength(b []byte, e escapes) (int, error) { } } -// prettyPrintInvertedIndexKey returns a string representation of the path part of a JSON inverted -// index. -func prettyPrintInvertedIndexKey(b []byte) (string, []byte, error) { - outBytes := "" - // We're skipping the first byte because it's the JSON tag. - tempB := b[1:] - for { - i := bytes.IndexByte(tempB, escape) - - if i == -1 { - return "", nil, errors.Errorf("did not find terminator %#x in buffer %#x", escape, b) - } - if i+1 >= len(tempB) { - return "", nil, errors.Errorf("malformed escape in buffer %#x", b) - } - - switch tempB[i+1] { - case escapedTerm: - if len(tempB[:i]) > 0 { - outBytes = outBytes + strconv.Quote(unsafeString(tempB[:i])) - } else { - lenOut := len(outBytes) - if lenOut > 1 && outBytes[lenOut-1] == '/' { - outBytes = outBytes[:lenOut-1] - } - } - return outBytes, tempB[i+escapeLength:], nil - case escapedJSONObjectKeyTerm: - outBytes = outBytes + strconv.Quote(unsafeString(tempB[:i])) + "/" - case escapedJSONArray: - outBytes = outBytes + "Arr/" - default: - return "", nil, errors.Errorf("malformed escape in buffer %#x", b) - - } - - tempB = tempB[i+escapeLength:] - } -} - // UnsafeConvertStringToBytes converts a string to a byte array to be used with // string encoding functions. Note that the output byte array should not be // modified if the input string is expected to be used again - doing so could @@ -783,52 +513,6 @@ func EncodeJSONEmptyObject(b []byte) []byte { return append(b, escape, escapedTerm, jsonEmptyObject) } -// EncodeStringDescending is the descending version of EncodeStringAscending. -func EncodeStringDescending(b []byte, s string) []byte { - if len(s) == 0 { - return EncodeBytesDescending(b, nil) - } - // We unsafely convert the string to a []byte to avoid the - // usual allocation when converting to a []byte. This is - // kosher because we know that EncodeBytes{,Descending} does - // not keep a reference to the value it encodes. The first - // step is getting access to the string internals. - hdr := (*reflect.StringHeader)(unsafe.Pointer(&s)) - // Next we treat the string data as a maximally sized array which we - // slice. This usage is safe because the pointer value remains in the string. - arg := (*[0x7fffffff]byte)(unsafe.Pointer(hdr.Data))[:len(s):len(s)] - return EncodeBytesDescending(b, arg) -} - -// unsafeString performs an unsafe conversion from a []byte to a string. The -// returned string will share the underlying memory with the []byte which thus -// allows the string to be mutable through the []byte. We're careful to use -// this method only in situations in which the []byte will not be modified. -func unsafeString(b []byte) string { - return *(*string)(unsafe.Pointer(&b)) -} - -// DecodeUnsafeStringAscending decodes a string value from the input buffer which was -// encoded using EncodeString or EncodeBytes. The r []byte is used as a -// temporary buffer in order to avoid memory allocations. The remainder of the -// input buffer and the decoded string are returned. Note that the returned -// string may share storage with the input buffer. -func DecodeUnsafeStringAscending(b []byte, r []byte) ([]byte, string, error) { - b, r, err := DecodeBytesAscending(b, r) - return b, unsafeString(r), err -} - -// DecodeUnsafeStringDescending decodes a string value from the input buffer which -// was encoded using EncodeStringDescending or EncodeBytesDescending. The r -// []byte is used as a temporary buffer in order to avoid memory -// allocations. The remainder of the input buffer and the decoded string are -// returned. Note that the returned string may share storage with the input -// buffer. -func DecodeUnsafeStringDescending(b []byte, r []byte) ([]byte, string, error) { - b, r, err := DecodeBytesDescending(b, r) - return b, unsafeString(r), err -} - // EncodeNullAscending encodes a NULL value. The encodes bytes are appended to the // supplied buffer and the final buffer is returned. The encoded value for a // NULL is guaranteed to not be a prefix for the EncodeVarint, EncodeFloat, @@ -843,18 +527,6 @@ func EncodeJSONAscending(b []byte) []byte { return append(b, jsonInvertedIndex) } -// EncodeNullDescending is the descending equivalent of EncodeNullAscending. -func EncodeNullDescending(b []byte) []byte { - return append(b, encodedNullDesc) -} - -// EncodeNotNullAscending encodes a value that is larger than the NULL marker encoded by -// EncodeNull but less than any encoded value returned by EncodeVarint, -// EncodeFloat, EncodeBytes or EncodeString. -func EncodeNotNullAscending(b []byte) []byte { - return append(b, encodedNotNull) -} - // EncodeArrayAscending encodes a value used to signify membership of an array for JSON objects. func EncodeArrayAscending(b []byte) []byte { return append(b, escape, escapedJSONArray) @@ -870,711 +542,155 @@ func EncodeFalseAscending(b []byte) []byte { return append(b, byte(False)) } -// EncodeNotNullDescending is the descending equivalent of EncodeNotNullAscending. -func EncodeNotNullDescending(b []byte) []byte { - return append(b, encodedNotNullDesc) -} - -// EncodeInterleavedSentinel encodes an interleavedSentinel that is necessary -// for interleaved indexes and their index keys. -// The interleavedSentinel has a byte value 0xfe and is equivalent to -// encodedNotNullDesc. -func EncodeInterleavedSentinel(b []byte) []byte { - return append(b, interleavedSentinel) -} - -// DecodeIfNull decodes a NULL value from the input buffer. If the input buffer -// contains a null at the start of the buffer then it is removed from the -// buffer and true is returned for the second result. Otherwise, the buffer is -// returned unchanged and false is returned for the second result. Since the -// NULL value encoding is guaranteed to never occur as the prefix for the -// EncodeVarint, EncodeFloat, EncodeBytes and EncodeString encodings, it is -// safe to call DecodeIfNull on their encoded values. -// This function handles both ascendingly and descendingly encoded NULLs. -func DecodeIfNull(b []byte) ([]byte, bool) { - if PeekType(b) == Null { - return b[1:], true - } - return b, false -} - -// DecodeIfNotNull decodes a not-NULL value from the input buffer. If the input -// buffer contains a not-NULL marker at the start of the buffer then it is -// removed from the buffer and true is returned for the second -// result. Otherwise, the buffer is returned unchanged and false is returned -// for the second result. Note that the not-NULL marker is identical to the -// empty string encoding, so do not use this routine where it is necessary to -// distinguish not-NULL from the empty string. -// This function handles both ascendingly and descendingly encoded NULLs. -func DecodeIfNotNull(b []byte) ([]byte, bool) { - if PeekType(b) == NotNull { - return b[1:], true +// getBitArrayWordsLen returns the number of bit array words in the +// encoded bytes and the size in bytes of the encoded word array +// (excluding the terminator byte). +func getBitArrayWordsLen(b []byte, term byte) (int, int, error) { + bSearch := b + numWords := 0 + sz := 0 + for { + if len(bSearch) == 0 { + return 0, 0, errors.Errorf("slice too short for bit array (%d)", len(b)) + } + if bSearch[0] == term { + break + } + vLen, err := getVarintLen(bSearch) + if err != nil { + return 0, 0, err + } + bSearch = bSearch[vLen:] + numWords++ + sz += vLen } - return b, false + return numWords, sz, nil } -// DecodeIfNotNullDescending decodes encodedNotNullDesc from the input buffer -// and returns the remaining buffer without the sentinel if encodedNotNullDesc -// is the first byte. -// Otherwise, the buffer is returned unchanged and false is returned. -func DecodeIfNotNullDescending(b []byte) ([]byte, bool) { - if len(b) == 0 { - return b, false - } - - if b[0] == encodedNotNullDesc { - return b[1:], true - } - - return b, false -} +// Type represents the type of a value encoded by +// Encode{Null,NotNull,Varint,Uvarint,Float,Bytes}. +//go:generate stringer -type=Type +type Type int -// DecodeIfInterleavedSentinel decodes the interleavedSentinel from the input -// buffer and returns the remaining buffer without the sentinel if the -// interleavedSentinel is the first byte. -// Otherwise, the buffer is returned unchanged and false is returned. -func DecodeIfInterleavedSentinel(b []byte) ([]byte, bool) { - // The interleavedSentinel is equivalent to encodedNotNullDesc - return DecodeIfNotNullDescending(b) -} +// Type values. +// TODO(dan, arjun): Make this into a proto enum. +// The 'Type' annotations are necessary for producing stringer-generated values. +const ( + Unknown Type = 0 + Null Type = 1 + NotNull Type = 2 + Int Type = 3 + Float Type = 4 + Decimal Type = 5 + Bytes Type = 6 + BytesDesc Type = 7 // Bytes encoded descendingly + Time Type = 8 + Duration Type = 9 + True Type = 10 + False Type = 11 + UUID Type = 12 + Array Type = 13 + IPAddr Type = 14 + // SentinelType is used for bit manipulation to check if the encoded type + // value requires more than 4 bits, and thus will be encoded in two bytes. It + // is not used as a type value, and thus intentionally overlaps with the + // subsequent type value. The 'Type' annotation is intentionally omitted here. + SentinelType = 15 + JSON Type = 15 + Tuple Type = 16 + BitArray Type = 17 + BitArrayDesc Type = 18 // BitArray encoded descendingly + TimeTZ Type = 19 + Geo Type = 20 + GeoDesc Type = 21 + ArrayKeyAsc Type = 22 // Array key encoding + ArrayKeyDesc Type = 23 // Array key encoded descendingly + Box2D Type = 24 +) -// EncodeTimeAscending encodes a time value, appends it to the supplied buffer, -// and returns the final buffer. The encoding is guaranteed to be ordered -// Such that if t1.Before(t2) then after EncodeTime(b1, t1), and -// EncodeTime(b2, t2), Compare(b1, b2) < 0. The time zone offset not -// included in the encoding. -func EncodeTimeAscending(b []byte, t time.Time) []byte { - return encodeTime(b, t.Unix(), int64(t.Nanosecond())) -} +// typMap maps an encoded type byte to a decoded Type. It's got 256 slots, one +// for every possible byte value. +var typMap [256]Type -// EncodeTimeDescending is the descending version of EncodeTimeAscending. -func EncodeTimeDescending(b []byte, t time.Time) []byte { - return encodeTime(b, ^t.Unix(), ^int64(t.Nanosecond())) +func init() { + buf := []byte{0} + for i := range typMap { + buf[0] = byte(i) + typMap[i] = slowPeekType(buf) + } } -func encodeTime(b []byte, unix, nanos int64) []byte { - // Read the unix absolute time. This is the absolute time and is - // not time zone offset dependent. - b = append(b, timeMarker) - b = EncodeVarintAscending(b, unix) - b = EncodeVarintAscending(b, nanos) - return b +// PeekType peeks at the type of the value encoded at the start of b. +func PeekType(b []byte) Type { + if len(b) >= 1 { + return typMap[b[0]] + } + return Unknown } -// DecodeTimeAscending decodes a time.Time value which was encoded using -// EncodeTime. The remainder of the input buffer and the decoded -// time.Time are returned. -func DecodeTimeAscending(b []byte) ([]byte, time.Time, error) { - b, sec, nsec, err := decodeTime(b) - if err != nil { - return b, time.Time{}, err +// slowPeekType is the old implementation of PeekType. It's used to generate +// the lookup table for PeekType. +func slowPeekType(b []byte) Type { + if len(b) >= 1 { + m := b[0] + switch { + case m == encodedNull, m == encodedNullDesc: + return Null + case m == encodedNotNull, m == encodedNotNullDesc: + return NotNull + case m == arrayKeyMarker: + return ArrayKeyAsc + case m == arrayKeyDescendingMarker: + return ArrayKeyDesc + case m == bytesMarker: + return Bytes + case m == bytesDescMarker: + return BytesDesc + case m == bitArrayMarker: + return BitArray + case m == bitArrayDescMarker: + return BitArrayDesc + case m == timeMarker: + return Time + case m == timeTZMarker: + return TimeTZ + case m == geoMarker: + return Geo + case m == box2DMarker: + return Box2D + case m == geoDescMarker: + return GeoDesc + case m == byte(Array): + return Array + case m == byte(True): + return True + case m == byte(False): + return False + case m == durationBigNegMarker, m == durationMarker, m == durationBigPosMarker: + return Duration + case m >= IntMin && m <= IntMax: + return Int + case m >= floatNaN && m <= floatNaNDesc: + return Float + case m >= decimalNaN && m <= decimalNaNDesc: + return Decimal + } } - return b, time.Unix(sec, nsec).UTC(), nil + return Unknown } -// DecodeTimeDescending is the descending version of DecodeTimeAscending. -func DecodeTimeDescending(b []byte) ([]byte, time.Time, error) { - b, sec, nsec, err := decodeTime(b) - if err != nil { - return b, time.Time{}, err +// GetMultiVarintLen find the length of encoded varints that follow a +// 1-byte tag. +func GetMultiVarintLen(b []byte, num int) (int, error) { + p := 1 + for i := 0; i < num && p < len(b); i++ { + len, err := getVarintLen(b[p:]) + if err != nil { + return 0, err + } + p += len } - return b, time.Unix(^sec, ^nsec).UTC(), nil -} - -func decodeTime(b []byte) (r []byte, sec int64, nsec int64, err error) { - if PeekType(b) != Time { - return nil, 0, 0, errors.Errorf("did not find marker") - } - b = b[1:] - b, sec, err = DecodeVarintAscending(b) - if err != nil { - return b, 0, 0, err - } - b, nsec, err = DecodeVarintAscending(b) - if err != nil { - return b, 0, 0, err - } - return b, sec, nsec, nil -} - -// EncodeBox2DAscending encodes a bounding box in ascending order. -func EncodeBox2DAscending(b []byte, box geo.CartesianBoundingBox) ([]byte, error) { - b = append(b, box2DMarker) - b = EncodeFloatAscending(b, box.LoX) - b = EncodeFloatAscending(b, box.HiX) - b = EncodeFloatAscending(b, box.LoY) - b = EncodeFloatAscending(b, box.HiY) - return b, nil -} - -// EncodeBox2DDescending encodes a bounding box in descending order. -func EncodeBox2DDescending(b []byte, box geo.CartesianBoundingBox) ([]byte, error) { - b = append(b, box2DMarker) - b = EncodeFloatDescending(b, box.LoX) - b = EncodeFloatDescending(b, box.HiX) - b = EncodeFloatDescending(b, box.LoY) - b = EncodeFloatDescending(b, box.HiY) - return b, nil -} - -// DecodeBox2DAscending decodes a box2D object in ascending order. -func DecodeBox2DAscending(b []byte) ([]byte, geo.CartesianBoundingBox, error) { - box := geo.CartesianBoundingBox{} - if PeekType(b) != Box2D { - return nil, box, errors.Errorf("did not find Box2D marker") - } - - b = b[1:] - var err error - b, box.LoX, err = DecodeFloatAscending(b) - if err != nil { - return nil, box, err - } - b, box.HiX, err = DecodeFloatAscending(b) - if err != nil { - return nil, box, err - } - b, box.LoY, err = DecodeFloatAscending(b) - if err != nil { - return nil, box, err - } - b, box.HiY, err = DecodeFloatAscending(b) - if err != nil { - return nil, box, err - } - return b, box, nil -} - -// DecodeBox2DDescending decodes a box2D object in descending order. -func DecodeBox2DDescending(b []byte) ([]byte, geo.CartesianBoundingBox, error) { - box := geo.CartesianBoundingBox{} - if PeekType(b) != Box2D { - return nil, box, errors.Errorf("did not find Box2D marker") - } - - b = b[1:] - var err error - b, box.LoX, err = DecodeFloatDescending(b) - if err != nil { - return nil, box, err - } - b, box.HiX, err = DecodeFloatDescending(b) - if err != nil { - return nil, box, err - } - b, box.LoY, err = DecodeFloatDescending(b) - if err != nil { - return nil, box, err - } - b, box.HiY, err = DecodeFloatDescending(b) - if err != nil { - return nil, box, err - } - return b, box, nil -} - -// EncodeGeoAscending encodes a geopb.SpatialObject value in ascending order and -// returns the new buffer. -// It is sorted by the given curve index, followed by the bytes of the spatial object. -func EncodeGeoAscending(b []byte, curveIndex uint64, so *geopb.SpatialObject) ([]byte, error) { - b = append(b, geoMarker) - b = EncodeUint64Ascending(b, curveIndex) - - data, err := protoutil.Marshal(so) - if err != nil { - return nil, err - } - b = encodeBytesAscendingWithTerminator(b, data, ascendingGeoEscapes.escapedTerm) - return b, nil -} - -// EncodeGeoDescending encodes a geopb.SpatialObject value in descending order and -// returns the new buffer. -// It is sorted by the given curve index, followed by the bytes of the spatial object. -func EncodeGeoDescending(b []byte, curveIndex uint64, so *geopb.SpatialObject) ([]byte, error) { - b = append(b, geoDescMarker) - b = EncodeUint64Descending(b, curveIndex) - - data, err := protoutil.Marshal(so) - if err != nil { - return nil, err - } - n := len(b) - b = encodeBytesAscendingWithTerminator(b, data, ascendingGeoEscapes.escapedTerm) - if err != nil { - return nil, err - } - onesComplement(b[n:]) - return b, nil -} - -// DecodeGeoAscending decodes a geopb.SpatialObject value that was encoded -// in ascending order back into a geopb.SpatialObject. The so parameter -// must already be empty/reset. -func DecodeGeoAscending(b []byte, so *geopb.SpatialObject) ([]byte, error) { - if PeekType(b) != Geo { - return nil, errors.Errorf("did not find Geo marker") - } - b = b[1:] - var err error - b, _, err = DecodeUint64Ascending(b) - if err != nil { - return nil, err - } - - var pbBytes []byte - b, pbBytes, err = decodeBytesInternal(b, pbBytes, ascendingGeoEscapes, false /* expectMarker */) - if err != nil { - return b, err - } - // Not using protoutil.Unmarshal since the call to so.Reset() will waste the - // pre-allocated EWKB. - err = so.Unmarshal(pbBytes) - return b, err -} - -// DecodeGeoDescending decodes a geopb.SpatialObject value that was encoded -// in descending order back into a geopb.SpatialObject. The so parameter -// must already be empty/reset. -func DecodeGeoDescending(b []byte, so *geopb.SpatialObject) ([]byte, error) { - if PeekType(b) != GeoDesc { - return nil, errors.Errorf("did not find Geo marker") - } - b = b[1:] - var err error - b, _, err = DecodeUint64Descending(b) - if err != nil { - return nil, err - } - - var pbBytes []byte - b, pbBytes, err = decodeBytesInternal(b, pbBytes, descendingGeoEscapes, false /* expectMarker */) - if err != nil { - return b, err - } - onesComplement(pbBytes) - // Not using protoutil.Unmarshal since the call to so.Reset() will waste the - // pre-allocated EWKB. - err = so.Unmarshal(pbBytes) - return b, err -} - -// EncodeTimeTZAscending encodes a timetz.TimeTZ value and appends it to -// the supplied buffer and returns the final buffer. -// The encoding is guaranteed to be ordered such that if t1.Before(t2) -// then after encodeTimeTZ(b1, t1) and encodeTimeTZ(b2, t2), -// Compare(b1, b2) < 0. -// The time zone offset is included in the encoding. -func EncodeTimeTZAscending(b []byte, t timetz.TimeTZ) []byte { - // Do not use TimeOfDay's add function, as it loses 24:00:00 encoding. - return encodeTimeTZ(b, int64(t.TimeOfDay)+int64(t.OffsetSecs)*offsetSecsToMicros, t.OffsetSecs) -} - -// EncodeTimeTZDescending is the descending version of EncodeTimeTZAscending. -func EncodeTimeTZDescending(b []byte, t timetz.TimeTZ) []byte { - // Do not use TimeOfDay's add function, as it loses 24:00:00 encoding. - return encodeTimeTZ(b, ^(int64(t.TimeOfDay) + int64(t.OffsetSecs)*offsetSecsToMicros), ^t.OffsetSecs) -} - -func encodeTimeTZ(b []byte, unixMicros int64, offsetSecs int32) []byte { - b = append(b, timeTZMarker) - b = EncodeVarintAscending(b, unixMicros) - b = EncodeVarintAscending(b, int64(offsetSecs)) - return b -} - -// DecodeTimeTZAscending decodes a timetz.TimeTZ value which was encoded -// using encodeTimeTZ. The remainder of the input buffer and the decoded -// timetz.TimeTZ are returned. -func DecodeTimeTZAscending(b []byte) ([]byte, timetz.TimeTZ, error) { - b, unixMicros, offsetSecs, err := decodeTimeTZ(b) - if err != nil { - return nil, timetz.TimeTZ{}, err - } - // Do not use timeofday.FromInt, as it loses 24:00:00 encoding. - return b, timetz.TimeTZ{ - TimeOfDay: timeofday.TimeOfDay(unixMicros - int64(offsetSecs)*offsetSecsToMicros), - OffsetSecs: offsetSecs, - }, nil -} - -// DecodeTimeTZDescending is the descending version of DecodeTimeTZAscending. -func DecodeTimeTZDescending(b []byte) ([]byte, timetz.TimeTZ, error) { - b, unixMicros, offsetSecs, err := decodeTimeTZ(b) - if err != nil { - return nil, timetz.TimeTZ{}, err - } - // Do not use timeofday.FromInt, as it loses 24:00:00 encoding. - return b, timetz.TimeTZ{ - TimeOfDay: timeofday.TimeOfDay(^unixMicros - int64(^offsetSecs)*offsetSecsToMicros), - OffsetSecs: ^offsetSecs, - }, nil -} - -func decodeTimeTZ(b []byte) ([]byte, int64, int32, error) { - if PeekType(b) != TimeTZ { - return nil, 0, 0, errors.Errorf("did not find marker") - } - b = b[1:] - var err error - var unixMicros int64 - b, unixMicros, err = DecodeVarintAscending(b) - if err != nil { - return nil, 0, 0, err - } - var offsetSecs int64 - b, offsetSecs, err = DecodeVarintAscending(b) - if err != nil { - return nil, 0, 0, err - } - return b, unixMicros, int32(offsetSecs), nil -} - -// EncodeDurationAscending encodes a duration.Duration value, appends it to the -// supplied buffer, and returns the final buffer. The encoding is guaranteed to -// be ordered such that if t1.Compare(t2) < 0 (or = 0 or > 0) then bytes.Compare -// will order them the same way after encoding. -func EncodeDurationAscending(b []byte, d duration.Duration) ([]byte, error) { - sortNanos, months, days, err := d.Encode() - if err != nil { - // TODO(dan): Handle this using d.EncodeBigInt() and the - // durationBigNeg/durationBigPos markers. - return b, err - } - b = append(b, durationMarker) - b = EncodeVarintAscending(b, sortNanos) - b = EncodeVarintAscending(b, months) - b = EncodeVarintAscending(b, days) - return b, nil -} - -// EncodeDurationDescending is the descending version of EncodeDurationAscending. -func EncodeDurationDescending(b []byte, d duration.Duration) ([]byte, error) { - sortNanos, months, days, err := d.Encode() - if err != nil { - // TODO(dan): Handle this using d.EncodeBigInt() and the - // durationBigNeg/durationBigPos markers. - return b, err - } - b = append(b, durationMarker) - b = EncodeVarintDescending(b, sortNanos) - b = EncodeVarintDescending(b, months) - b = EncodeVarintDescending(b, days) - return b, nil -} - -// DecodeDurationAscending decodes a duration.Duration value which was encoded -// using EncodeDurationAscending. The remainder of the input buffer and the -// decoded duration.Duration are returned. -func DecodeDurationAscending(b []byte) ([]byte, duration.Duration, error) { - if PeekType(b) != Duration { - return nil, duration.Duration{}, errors.Errorf("did not find marker %x", b) - } - b = b[1:] - b, sortNanos, err := DecodeVarintAscending(b) - if err != nil { - return b, duration.Duration{}, err - } - b, months, err := DecodeVarintAscending(b) - if err != nil { - return b, duration.Duration{}, err - } - b, days, err := DecodeVarintAscending(b) - if err != nil { - return b, duration.Duration{}, err - } - d, err := duration.Decode(sortNanos, months, days) - if err != nil { - return b, duration.Duration{}, err - } - return b, d, nil -} - -// DecodeDurationDescending is the descending version of DecodeDurationAscending. -func DecodeDurationDescending(b []byte) ([]byte, duration.Duration, error) { - if PeekType(b) != Duration { - return nil, duration.Duration{}, errors.Errorf("did not find marker") - } - b = b[1:] - b, sortNanos, err := DecodeVarintDescending(b) - if err != nil { - return b, duration.Duration{}, err - } - b, months, err := DecodeVarintDescending(b) - if err != nil { - return b, duration.Duration{}, err - } - b, days, err := DecodeVarintDescending(b) - if err != nil { - return b, duration.Duration{}, err - } - d, err := duration.Decode(sortNanos, months, days) - if err != nil { - return b, duration.Duration{}, err - } - return b, d, nil -} - -// EncodeBitArrayAscending encodes a bitarray.BitArray value, appends it to the -// supplied buffer, and returns the final buffer. The encoding is guaranteed to -// be ordered such that if t1.Compare(t2) < 0 (or = 0 or > 0) then bytes.Compare -// will order them the same way after encoding. -// -// The encoding uses varint encoding for each word of the backing -// array. This is a trade-off. The alternative is to encode the entire -// backing word array as a byte array, using byte array encoding and escaped -// special bytes (via `encodeBytesAscendingWithoutTerminatorOrPrefix`). -// There are two arguments against this alternative: -// - the bytes must be encoded big endian, but the most common architectures -// running CockroachDB are little-endian, so the bytes would need -// to be reordered prior to encoding. -// - when decoding or skipping over a value, the decoding/sizing loop -// would need to look at every byte of the encoding to find the -// terminator. -// In contrast, the chosen encoding using varints is endianness-agnostic -// and enables fast decoding/skipping thanks ot the tag bytes. -func EncodeBitArrayAscending(b []byte, d utils.BitArray) []byte { - b = append(b, bitArrayMarker) - words, lastBitsUsed := d.EncodingParts() - for _, w := range words { - b = EncodeUvarintAscending(b, w) - } - b = append(b, bitArrayDataTerminator) - b = EncodeUvarintAscending(b, lastBitsUsed) - return b -} - -// EncodeBitArrayDescending is the descending version of EncodeBitArrayAscending. -func EncodeBitArrayDescending(b []byte, d utils.BitArray) []byte { - b = append(b, bitArrayDescMarker) - words, lastBitsUsed := d.EncodingParts() - for _, w := range words { - b = EncodeUvarintDescending(b, w) - } - b = append(b, bitArrayDataDescTerminator) - b = EncodeUvarintDescending(b, lastBitsUsed) - return b -} - -// DecodeBitArrayAscending decodes a bit array which was encoded using -// EncodeBitArrayAscending. The remainder of the input buffer and the -// decoded bit array are returned. -func DecodeBitArrayAscending(b []byte) ([]byte, utils.BitArray, error) { - if PeekType(b) != BitArray { - return nil, utils.BitArray{}, errors.Errorf("did not find marker %x", b) - } - b = b[1:] - - // First compute the length. - numWords, _, err := getBitArrayWordsLen(b, bitArrayDataTerminator) - if err != nil { - return b, utils.BitArray{}, err - } - // Decode the words. - words := make([]uint64, numWords) - for i := range words { - b, words[i], err = DecodeUvarintAscending(b) - if err != nil { - return b, utils.BitArray{}, err - } - } - // Decode the final part. - if len(b) == 0 || b[0] != bitArrayDataTerminator { - return b, utils.BitArray{}, errBitArrayTerminatorMissing - } - b = b[1:] - b, lastVal, err := DecodeUvarintAscending(b) - if err != nil { - return b, utils.BitArray{}, err - } - ba, err := utils.FromEncodingParts(words, lastVal) - return b, ba, err -} - -var errBitArrayTerminatorMissing = errors.New("cannot find bit array data terminator") - -// getBitArrayWordsLen returns the number of bit array words in the -// encoded bytes and the size in bytes of the encoded word array -// (excluding the terminator byte). -func getBitArrayWordsLen(b []byte, term byte) (int, int, error) { - bSearch := b - numWords := 0 - sz := 0 - for { - if len(bSearch) == 0 { - return 0, 0, errors.Errorf("slice too short for bit array (%d)", len(b)) - } - if bSearch[0] == term { - break - } - vLen, err := getVarintLen(bSearch) - if err != nil { - return 0, 0, err - } - bSearch = bSearch[vLen:] - numWords++ - sz += vLen - } - return numWords, sz, nil -} - -// DecodeBitArrayDescending is the descending version of DecodeBitArrayAscending. -func DecodeBitArrayDescending(b []byte) ([]byte, utils.BitArray, error) { - if PeekType(b) != BitArrayDesc { - return nil, utils.BitArray{}, errors.Errorf("did not find marker %x", b) - } - b = b[1:] - - // First compute the length. - numWords, _, err := getBitArrayWordsLen(b, bitArrayDataDescTerminator) - if err != nil { - return b, utils.BitArray{}, err - } - // Decode the words. - words := make([]uint64, numWords) - for i := range words { - b, words[i], err = DecodeUvarintDescending(b) - if err != nil { - return b, utils.BitArray{}, err - } - } - // Decode the final part. - if len(b) == 0 || b[0] != bitArrayDataDescTerminator { - return b, utils.BitArray{}, errBitArrayTerminatorMissing - } - b = b[1:] - b, lastVal, err := DecodeUvarintDescending(b) - if err != nil { - return b, utils.BitArray{}, err - } - ba, err := utils.FromEncodingParts(words, lastVal) - return b, ba, err -} - -// Type represents the type of a value encoded by -// Encode{Null,NotNull,Varint,Uvarint,Float,Bytes}. -//go:generate stringer -type=Type -type Type int - -// Type values. -// TODO(dan, arjun): Make this into a proto enum. -// The 'Type' annotations are necessary for producing stringer-generated values. -const ( - Unknown Type = 0 - Null Type = 1 - NotNull Type = 2 - Int Type = 3 - Float Type = 4 - Decimal Type = 5 - Bytes Type = 6 - BytesDesc Type = 7 // Bytes encoded descendingly - Time Type = 8 - Duration Type = 9 - True Type = 10 - False Type = 11 - UUID Type = 12 - Array Type = 13 - IPAddr Type = 14 - // SentinelType is used for bit manipulation to check if the encoded type - // value requires more than 4 bits, and thus will be encoded in two bytes. It - // is not used as a type value, and thus intentionally overlaps with the - // subsequent type value. The 'Type' annotation is intentionally omitted here. - SentinelType = 15 - JSON Type = 15 - Tuple Type = 16 - BitArray Type = 17 - BitArrayDesc Type = 18 // BitArray encoded descendingly - TimeTZ Type = 19 - Geo Type = 20 - GeoDesc Type = 21 - ArrayKeyAsc Type = 22 // Array key encoding - ArrayKeyDesc Type = 23 // Array key encoded descendingly - Box2D Type = 24 -) - -// typMap maps an encoded type byte to a decoded Type. It's got 256 slots, one -// for every possible byte value. -var typMap [256]Type - -func init() { - buf := []byte{0} - for i := range typMap { - buf[0] = byte(i) - typMap[i] = slowPeekType(buf) - } -} - -// PeekType peeks at the type of the value encoded at the start of b. -func PeekType(b []byte) Type { - if len(b) >= 1 { - return typMap[b[0]] - } - return Unknown -} - -// slowPeekType is the old implementation of PeekType. It's used to generate -// the lookup table for PeekType. -func slowPeekType(b []byte) Type { - if len(b) >= 1 { - m := b[0] - switch { - case m == encodedNull, m == encodedNullDesc: - return Null - case m == encodedNotNull, m == encodedNotNullDesc: - return NotNull - case m == arrayKeyMarker: - return ArrayKeyAsc - case m == arrayKeyDescendingMarker: - return ArrayKeyDesc - case m == bytesMarker: - return Bytes - case m == bytesDescMarker: - return BytesDesc - case m == bitArrayMarker: - return BitArray - case m == bitArrayDescMarker: - return BitArrayDesc - case m == timeMarker: - return Time - case m == timeTZMarker: - return TimeTZ - case m == geoMarker: - return Geo - case m == box2DMarker: - return Box2D - case m == geoDescMarker: - return GeoDesc - case m == byte(Array): - return Array - case m == byte(True): - return True - case m == byte(False): - return False - case m == durationBigNegMarker, m == durationMarker, m == durationBigPosMarker: - return Duration - case m >= IntMin && m <= IntMax: - return Int - case m >= floatNaN && m <= floatNaNDesc: - return Float - case m >= decimalNaN && m <= decimalNaNDesc: - return Decimal - } - } - return Unknown -} - -// GetMultiVarintLen find the length of encoded varints that follow a -// 1-byte tag. -func GetMultiVarintLen(b []byte, num int) (int, error) { - p := 1 - for i := 0; i < num && p < len(b); i++ { - len, err := getVarintLen(b[p:]) - if err != nil { - return 0, err - } - p += len - } - return p, nil + return p, nil } // getMultiNonsortingVarintLen finds the length of encoded nonsorting varints. @@ -1727,590 +843,46 @@ func PeekLength(b []byte) (int, error) { if m >= decimalNaN && m <= decimalNaNDesc { return getDecimalLen(b) } - return 0, errors.Errorf("unknown tag %d", m) -} - -// PrettyPrintValue returns the string representation of all contiguous -// decodable values in the provided byte slice, separated by a provided -// separator. -// The directions each value is encoded may be provided. If valDirs is nil, -// all values are decoded and printed with the default direction (ascending). -func PrettyPrintValue(valDirs []Direction, b []byte, sep string) string { - s1, allDecoded := prettyPrintValueImpl(valDirs, b, sep) - if allDecoded { - return s1 - } - if undoPrefixEnd, ok := UndoPrefixEnd(b); ok { - // When we UndoPrefixEnd, we may have lost a tail of 0xFFs. Try to add - // enough of them to get something decoded. This is best-effort, we have to stop - // somewhere. - cap := 20 - if len(valDirs) > len(b) { - cap = len(valDirs) - len(b) - } - for i := 0; i < cap; i++ { - if s2, allDecoded := prettyPrintValueImpl(valDirs, undoPrefixEnd, sep); allDecoded { - return s2 + sep + "PrefixEnd" - } - undoPrefixEnd = append(undoPrefixEnd, 0xFF) - } - } - return s1 -} - -func prettyPrintValueImpl(valDirs []Direction, b []byte, sep string) (string, bool) { - allDecoded := true - var buf strings.Builder - for len(b) > 0 { - // If there are more values than encoding directions specified, - // valDir will contain the 0 value of Direction. - // prettyPrintFirstValue will then use the default encoding - // direction per each value type. - var valDir Direction - if len(valDirs) > 0 { - valDir = valDirs[0] - valDirs = valDirs[1:] - } - - bb, s, err := prettyPrintFirstValue(valDir, b) - if err != nil { - allDecoded = false - buf.WriteString(sep) - buf.WriteByte('?') - buf.WriteByte('?') - buf.WriteByte('?') - } else { - buf.WriteString(sep) - buf.WriteString(s) - } - b = bb - } - return buf.String(), allDecoded -} - -// prettyPrintFirstValue returns a string representation of the first decodable -// value in the provided byte slice, along with the remaining byte slice -// after decoding. -// -// Ascending will be the default direction (when dir is the 0 value) for all -// values except for NotNull. -// -// NotNull: if Ascending or Descending directions are explicitly provided (i.e. -// for table keys), then !NULL will be used. Otherwise, # will be used. -// -// We prove that the default # will only be used for interleaved sentinels: -// - For non-table keys, we never have NotNull. -// - For table keys, we always explicitly pass in Ascending and Descending for -// all key values, including NotNulls. The only case we do not pass in -// direction is during a SHOW RANGES ON TABLE parent and there exists -// an interleaved split key. Note that interleaved keys cannot have NotNull -// values except for the interleaved sentinel. -// -// Defaulting to Ascending for all other value types is fine since all -// non-table keys encode values with Ascending. -// -// The only case where we end up defaulting direction for table keys is for -// interleaved split keys in SHOW RANGES ON TABLE parent. Since -// interleaved prefixes are defined on the primary key (and primary key values -// are always encoded Ascending), this will always print out the correct key -// even if we don't have directions for the child index's columns. -func prettyPrintFirstValue(dir Direction, b []byte) ([]byte, string, error) { - var err error - switch typ := PeekType(b); typ { - case Null: - b, _ = DecodeIfNull(b) - return b, "NULL", nil - case True: - return b[1:], "True", nil - case False: - return b[1:], "False", nil - case Array: - return b[1:], "Arr", nil - case ArrayKeyAsc, ArrayKeyDesc: - encDir := Ascending - if typ == ArrayKeyDesc { - encDir = Descending - } - var build strings.Builder - buf, err := ValidateAndConsumeArrayKeyMarker(b, encDir) - if err != nil { - return nil, "", err - } - build.WriteString("ARRAY[") - first := true - // Use the array key decoding logic, but instead of calling out - // to DecodeTableKey, just make a recursive call. - for { - if len(buf) == 0 { - return nil, "", errors.AssertionFailedf("invalid array (unterminated)") - } - if IsArrayKeyDone(buf, encDir) { - buf = buf[1:] - break - } - var next string - if IsNextByteArrayEncodedNull(buf, dir) { - next = "NULL" - buf = buf[1:] - } else { - buf, next, err = prettyPrintFirstValue(dir, buf) - if err != nil { - return nil, "", err - } - } - if !first { - build.WriteString(",") - } - build.WriteString(next) - first = false - } - build.WriteString("]") - return buf, build.String(), nil - case NotNull: - // The tag can be either encodedNotNull or encodedNotNullDesc. The - // latter can be an interleaved sentinel. - isNotNullDesc := (b[0] == encodedNotNullDesc) - b, _ = DecodeIfNotNull(b) - if dir != Ascending && dir != Descending && isNotNullDesc { - // Unspecified direction (0 value) will default to '#' for the - // interleaved sentinel. - return b, "#", nil - } - return b, "!NULL", nil - case Int: - var i int64 - if dir == Descending { - b, i, err = DecodeVarintDescending(b) - } else { - b, i, err = DecodeVarintAscending(b) - } - if err != nil { - return b, "", err - } - return b, strconv.FormatInt(i, 10), nil - case Float: - var f float64 - if dir == Descending { - b, f, err = DecodeFloatDescending(b) - } else { - b, f, err = DecodeFloatAscending(b) - } - if err != nil { - return b, "", err - } - return b, strconv.FormatFloat(f, 'g', -1, 64), nil - case Decimal: - var d apd.Decimal - if dir == Descending { - b, d, err = DecodeDecimalDescending(b, nil) - } else { - b, d, err = DecodeDecimalAscending(b, nil) - } - if err != nil { - return b, "", err - } - return b, d.String(), nil - case BitArray: - if dir == Descending { - return b, "", errors.Errorf("descending bit column dir but ascending bit array encoding") - } - var d utils.BitArray - b, d, err = DecodeBitArrayAscending(b) - return b, "B" + d.String(), err - case BitArrayDesc: - if dir == Ascending { - return b, "", errors.Errorf("ascending bit column dir but descending bit array encoding") - } - var d utils.BitArray - b, d, err = DecodeBitArrayDescending(b) - return b, "B" + d.String(), err - case Bytes: - if dir == Descending { - return b, "", errors.Errorf("descending bytes column dir but ascending bytes encoding") - } - var s string - b, s, err = DecodeUnsafeStringAscending(b, nil) - if err != nil { - return b, "", err - } - return b, strconv.Quote(s), nil - case BytesDesc: - if dir == Ascending { - return b, "", errors.Errorf("ascending bytes column dir but descending bytes encoding") - } - - var s string - b, s, err = DecodeUnsafeStringDescending(b, nil) - if err != nil { - return b, "", err - } - return b, strconv.Quote(s), nil - case Time: - var t time.Time - if dir == Descending { - b, t, err = DecodeTimeDescending(b) - } else { - b, t, err = DecodeTimeAscending(b) - } - if err != nil { - return b, "", err - } - return b, t.UTC().Format(time.RFC3339Nano), nil - case TimeTZ: - var t timetz.TimeTZ - if dir == Descending { - b, t, err = DecodeTimeTZDescending(b) - } else { - b, t, err = DecodeTimeTZAscending(b) - } - if err != nil { - return b, "", err - } - return b, t.String(), nil - case Duration: - var d duration.Duration - if dir == Descending { - b, d, err = DecodeDurationDescending(b) - } else { - b, d, err = DecodeDurationAscending(b) - } - if err != nil { - return b, "", err - } - return b, d.StringNanos(), nil - default: - if len(b) >= 1 { - switch b[0] { - case jsonInvertedIndex: - var str string - str, b, err = prettyPrintInvertedIndexKey(b) - if err != nil { - return b, "", err - } - if str == "" { - return prettyPrintFirstValue(dir, b) - } - return b, str, nil - case jsonEmptyArray: - return b[1:], "[]", nil - case jsonEmptyObject: - return b[1:], "{}", nil - } - } - // This shouldn't ever happen, but if it does, return an empty slice. - return nil, strconv.Quote(string(b)), nil - } -} - -// UndoPrefixEnd is a partial inverse for roachpb.Key.PrefixEnd. -// -// In general, we can't undo PrefixEnd because it is lossy; we don't know how -// many FFs were stripped from the original key. For example: -// - key: 01 02 03 FF FF -// - PrefixEnd: 01 02 04 -// - UndoPrefixEnd: 01 02 03 -// -// Some keys are not possible results of PrefixEnd; in particular, PrefixEnd -// keys never end in 00. If an impossible key is passed, the second return value -// is false. -// -// Specifically, calling UndoPrefixEnd will reverse the effects of calling a -// PrefixEnd on a byte sequence, except when the byte sequence represents a -// maximal prefix (i.e., 0xff...). This is because PrefixEnd is a lossy -// operation: PrefixEnd(0xff) returns 0xff rather than wrapping around to the -// minimal prefix 0x00. For consistency, UndoPrefixEnd is also lossy: -// UndoPrefixEnd(0x00) returns 0x00 rather than wrapping around to the maximal -// prefix 0xff. -// -// Formally: -// -// PrefixEnd(UndoPrefixEnd(p)) = p for all non-minimal prefixes p -// UndoPrefixEnd(PrefixEnd(p)) = p for all non-maximal prefixes p -// -// A minimal prefix is any prefix that consists only of one or more 0x00 bytes; -// analogously, a maximal prefix is any prefix that consists only of one or more -// 0xff bytes. -// -// UndoPrefixEnd is implemented here to avoid a circular dependency on roachpb, -// but arguably belongs in a byte-manipulation utility package. -func UndoPrefixEnd(b []byte) (_ []byte, ok bool) { - if len(b) == 0 || b[len(b)-1] == 0 { - // Not a possible result of PrefixEnd. - return nil, false - } - out := append([]byte(nil), b...) - out[len(out)-1]-- - return out, true -} - -// MaxNonsortingVarintLen is the maximum length of an EncodeNonsortingVarint -// encoded value. -const MaxNonsortingVarintLen = binary.MaxVarintLen64 - -// EncodeNonsortingStdlibVarint encodes an int value using encoding/binary, appends it -// to the supplied buffer, and returns the final buffer. -func EncodeNonsortingStdlibVarint(appendTo []byte, x int64) []byte { - // Fixed size array to allocate this on the stack. - var scratch [binary.MaxVarintLen64]byte - i := binary.PutVarint(scratch[:binary.MaxVarintLen64], x) - return append(appendTo, scratch[:i]...) -} - -// DecodeNonsortingStdlibVarint decodes a value encoded by EncodeNonsortingVarint. It -// returns the length of the encoded varint and value. -func DecodeNonsortingStdlibVarint(b []byte) (remaining []byte, length int, value int64, err error) { - value, length = binary.Varint(b) - if length <= 0 { - return nil, 0, 0, fmt.Errorf("int64 varint decoding failed: %d", length) - } - return b[length:], length, value, nil -} - -// MaxNonsortingUvarintLen is the maximum length of an EncodeNonsortingUvarint -// encoded value. -const MaxNonsortingUvarintLen = 10 - -// EncodeNonsortingUvarint encodes a uint64, appends it to the supplied buffer, -// and returns the final buffer. The encoding used is similar to -// encoding/binary, but with the most significant bits first: -// - Unsigned integers are serialized 7 bits at a time, starting with the -// most significant bits. -// - The most significant bit (msb) in each output byte indicates if there -// is a continuation byte (msb = 1). -func EncodeNonsortingUvarint(appendTo []byte, x uint64) []byte { - switch { - case x < (1 << 7): - return append(appendTo, byte(x)) - case x < (1 << 14): - return append(appendTo, 0x80|byte(x>>7), 0x7f&byte(x)) - case x < (1 << 21): - return append(appendTo, 0x80|byte(x>>14), 0x80|byte(x>>7), 0x7f&byte(x)) - case x < (1 << 28): - return append(appendTo, 0x80|byte(x>>21), 0x80|byte(x>>14), 0x80|byte(x>>7), 0x7f&byte(x)) - case x < (1 << 35): - return append(appendTo, 0x80|byte(x>>28), 0x80|byte(x>>21), 0x80|byte(x>>14), 0x80|byte(x>>7), 0x7f&byte(x)) - case x < (1 << 42): - return append(appendTo, 0x80|byte(x>>35), 0x80|byte(x>>28), 0x80|byte(x>>21), 0x80|byte(x>>14), 0x80|byte(x>>7), 0x7f&byte(x)) - case x < (1 << 49): - return append(appendTo, 0x80|byte(x>>42), 0x80|byte(x>>35), 0x80|byte(x>>28), 0x80|byte(x>>21), 0x80|byte(x>>14), 0x80|byte(x>>7), 0x7f&byte(x)) - case x < (1 << 56): - return append(appendTo, 0x80|byte(x>>49), 0x80|byte(x>>42), 0x80|byte(x>>35), 0x80|byte(x>>28), 0x80|byte(x>>21), 0x80|byte(x>>14), 0x80|byte(x>>7), 0x7f&byte(x)) - case x < (1 << 63): - return append(appendTo, 0x80|byte(x>>56), 0x80|byte(x>>49), 0x80|byte(x>>42), 0x80|byte(x>>35), 0x80|byte(x>>28), 0x80|byte(x>>21), 0x80|byte(x>>14), 0x80|byte(x>>7), 0x7f&byte(x)) - default: - return append(appendTo, 0x80|byte(x>>63), 0x80|byte(x>>56), 0x80|byte(x>>49), 0x80|byte(x>>42), 0x80|byte(x>>35), 0x80|byte(x>>28), 0x80|byte(x>>21), 0x80|byte(x>>14), 0x80|byte(x>>7), 0x7f&byte(x)) - } -} - -// DecodeNonsortingUvarint decodes a value encoded by EncodeNonsortingUvarint. It -// returns the length of the encoded varint and value. -func DecodeNonsortingUvarint(buf []byte) (remaining []byte, length int, value uint64, err error) { - // TODO(dan): Handle overflow. - for i, b := range buf { - value += uint64(b & 0x7f) - if b < 0x80 { - return buf[i+1:], i + 1, value, nil - } - value <<= 7 - } - return buf, 0, 0, nil -} - -// DecodeNonsortingStdlibUvarint decodes a value encoded with binary.PutUvarint. It -// returns the length of the encoded varint and value. -func DecodeNonsortingStdlibUvarint( - buf []byte, -) (remaining []byte, length int, value uint64, err error) { - i, n := binary.Uvarint(buf) - if n <= 0 { - return buf, 0, 0, errors.New("buffer too small") - } - return buf[n:], n, i, nil -} - -// PeekLengthNonsortingUvarint returns the length of the value that starts at -// the beginning of buf and was encoded by EncodeNonsortingUvarint. -func PeekLengthNonsortingUvarint(buf []byte) int { - for i, b := range buf { - if b&0x80 == 0 { - return i + 1 - } - } - return 0 -} - -// NoColumnID is a sentinel for the EncodeFooValue methods representing an -// invalid column id. -const NoColumnID uint32 = 0 - -// EncodeValueTag encodes the prefix that is used by each of the EncodeFooValue -// methods. -// -// The prefix uses varints to encode a column id and type, packing them into a -// single byte when they're small (colID < 8 and typ < 15). This works by -// shifting the colID "left" by 4 and putting any type less than 15 in the low -// bytes. The result is uvarint encoded and fits in one byte if the original -// column id fit in 3 bits. If it doesn't fit in one byte, the most significant -// bits spill to the "left", leaving the type bits always at the very "right". -// -// If the type is > 15, the reserved sentinel of 15 is placed in the type bits -// and a uvarint follows with the type value. This means that there are always -// one or two uvarints. -// -// Together, this means the everything but the last byte of the first uvarint -// can be dropped if the column id isn't needed. -func EncodeValueTag(appendTo []byte, colID uint32, typ Type) []byte { - if typ >= SentinelType { - appendTo = EncodeNonsortingUvarint(appendTo, uint64(colID)<<4|uint64(SentinelType)) - return EncodeNonsortingUvarint(appendTo, uint64(typ)) - } - if colID == NoColumnID { - // TODO(dan): EncodeValueTag is not inlined by the compiler. Copying this - // special case into one of the EncodeFooValue functions speeds it up by - // ~4ns. - return append(appendTo, byte(typ)) - } - return EncodeNonsortingUvarint(appendTo, uint64(colID)<<4|uint64(typ)) -} - -// EncodeNullValue encodes a null value, appends it to the supplied buffer, and -// returns the final buffer. -func EncodeNullValue(appendTo []byte, colID uint32) []byte { - return EncodeValueTag(appendTo, colID, Null) -} - -// EncodeNotNullValue encodes a not null value, appends it to the supplied -// buffer, and returns the final buffer. -func EncodeNotNullValue(appendTo []byte, colID uint32) []byte { - return EncodeValueTag(appendTo, colID, NotNull) -} - -// EncodeBoolValue encodes a bool value, appends it to the supplied buffer, and -// returns the final buffer. -func EncodeBoolValue(appendTo []byte, colID uint32, b bool) []byte { - if b { - return EncodeValueTag(appendTo, colID, True) - } - return EncodeValueTag(appendTo, colID, False) -} - -// EncodeIntValue encodes an int value with its value tag, appends it to the -// supplied buffer, and returns the final buffer. -func EncodeIntValue(appendTo []byte, colID uint32, i int64) []byte { - appendTo = EncodeValueTag(appendTo, colID, Int) - return EncodeUntaggedIntValue(appendTo, i) -} - -// EncodeUntaggedIntValue encodes an int value, appends it to the supplied buffer, and -// returns the final buffer. -func EncodeUntaggedIntValue(appendTo []byte, i int64) []byte { - return EncodeNonsortingStdlibVarint(appendTo, i) -} - -const floatValueEncodedLength = uint64AscendingEncodedLength - -// EncodeFloatValue encodes a float value with its value tag, appends it to the -// supplied buffer, and returns the final buffer. -func EncodeFloatValue(appendTo []byte, colID uint32, f float64) []byte { - appendTo = EncodeValueTag(appendTo, colID, Float) - return EncodeUntaggedFloatValue(appendTo, f) -} - -// EncodeUntaggedFloatValue encodes a float value, appends it to the supplied buffer, -// and returns the final buffer. -func EncodeUntaggedFloatValue(appendTo []byte, f float64) []byte { - return EncodeUint64Ascending(appendTo, math.Float64bits(f)) -} - -// EncodeBytesValue encodes a byte array value with its value tag, appends it to -// the supplied buffer, and returns the final buffer. -func EncodeBytesValue(appendTo []byte, colID uint32, data []byte) []byte { - appendTo = EncodeValueTag(appendTo, colID, Bytes) - return EncodeUntaggedBytesValue(appendTo, data) -} - -// EncodeUntaggedBytesValue encodes a byte array value, appends it to the supplied -// buffer, and returns the final buffer. -func EncodeUntaggedBytesValue(appendTo []byte, data []byte) []byte { - appendTo = EncodeNonsortingUvarint(appendTo, uint64(len(data))) - return append(appendTo, data...) -} - -// EncodeArrayValue encodes a byte array value with its value tag, appends it to -// the supplied buffer, and returns the final buffer. -func EncodeArrayValue(appendTo []byte, colID uint32, data []byte) []byte { - appendTo = EncodeValueTag(appendTo, colID, Array) - return EncodeUntaggedBytesValue(appendTo, data) -} - -// EncodeTimeValue encodes a time.Time value with its value tag, appends it to -// the supplied buffer, and returns the final buffer. -func EncodeTimeValue(appendTo []byte, colID uint32, t time.Time) []byte { - appendTo = EncodeValueTag(appendTo, colID, Time) - return EncodeUntaggedTimeValue(appendTo, t) -} - -// EncodeUntaggedTimeValue encodes a time.Time value, appends it to the supplied buffer, -// and returns the final buffer. -func EncodeUntaggedTimeValue(appendTo []byte, t time.Time) []byte { - appendTo = EncodeNonsortingStdlibVarint(appendTo, t.Unix()) - return EncodeNonsortingStdlibVarint(appendTo, int64(t.Nanosecond())) -} - -// EncodeTimeTZValue encodes a timetz.TimeTZ value with its value tag, appends it to -// the supplied buffer, and returns the final buffer. -func EncodeTimeTZValue(appendTo []byte, colID uint32, t timetz.TimeTZ) []byte { - appendTo = EncodeValueTag(appendTo, colID, TimeTZ) - return EncodeUntaggedTimeTZValue(appendTo, t) -} - -// EncodeUntaggedTimeTZValue encodes a time.Time value, appends it to the supplied buffer, -// and returns the final buffer. -func EncodeUntaggedTimeTZValue(appendTo []byte, t timetz.TimeTZ) []byte { - appendTo = EncodeNonsortingStdlibVarint(appendTo, int64(t.TimeOfDay)) - return EncodeNonsortingStdlibVarint(appendTo, int64(t.OffsetSecs)) -} - -// EncodeBox2DValue encodes a geo.CartesianBoundingBox with its value tag, appends it to -// the supplied buffer and returns the final buffer. -func EncodeBox2DValue(appendTo []byte, colID uint32, b geo.CartesianBoundingBox) ([]byte, error) { - appendTo = EncodeValueTag(appendTo, colID, Box2D) - return EncodeUntaggedBox2DValue(appendTo, b) + return 0, errors.Errorf("unknown tag %d", m) } -// EncodeUntaggedBox2DValue encodes a geo.CartesianBoundingBox value, appends it to the supplied buffer, -// and returns the final buffer. -func EncodeUntaggedBox2DValue(appendTo []byte, b geo.CartesianBoundingBox) ([]byte, error) { - appendTo = EncodeFloatAscending(appendTo, b.LoX) - appendTo = EncodeFloatAscending(appendTo, b.HiX) - appendTo = EncodeFloatAscending(appendTo, b.LoY) - appendTo = EncodeFloatAscending(appendTo, b.HiY) - return appendTo, nil +// DecodeNonsortingStdlibVarint decodes a value encoded by EncodeNonsortingVarint. It +// returns the length of the encoded varint and value. +func DecodeNonsortingStdlibVarint(b []byte) (remaining []byte, length int, value int64, err error) { + value, length = binary.Varint(b) + if length <= 0 { + return nil, 0, 0, fmt.Errorf("int64 varint decoding failed: %d", length) + } + return b[length:], length, value, nil } -// EncodeGeoValue encodes a geopb.SpatialObject value with its value tag, appends it to -// the supplied buffer, and returns the final buffer. -func EncodeGeoValue(appendTo []byte, colID uint32, so *geopb.SpatialObject) ([]byte, error) { - appendTo = EncodeValueTag(appendTo, colID, Geo) - return EncodeUntaggedGeoValue(appendTo, so) +// DecodeNonsortingUvarint decodes a value encoded by EncodeNonsortingUvarint. It +// returns the length of the encoded varint and value. +func DecodeNonsortingUvarint(buf []byte) (remaining []byte, length int, value uint64, err error) { + // TODO(dan): Handle overflow. + for i, b := range buf { + value += uint64(b & 0x7f) + if b < 0x80 { + return buf[i+1:], i + 1, value, nil + } + value <<= 7 + } + return buf, 0, 0, nil } -// EncodeUntaggedGeoValue encodes a geopb.SpatialObject value, appends it to the supplied buffer, -// and returns the final buffer. -func EncodeUntaggedGeoValue(appendTo []byte, so *geopb.SpatialObject) ([]byte, error) { - bytes, err := protoutil.Marshal(so) - if err != nil { - return nil, err +// DecodeNonsortingStdlibUvarint decodes a value encoded with binary.PutUvarint. It +// returns the length of the encoded varint and value. +func DecodeNonsortingStdlibUvarint( + buf []byte, +) (remaining []byte, length int, value uint64, err error) { + i, n := binary.Uvarint(buf) + if n <= 0 { + return buf, 0, 0, errors.New("buffer too small") } - return EncodeUntaggedBytesValue(appendTo, bytes), nil + return buf[n:], n, i, nil } -// EncodeDecimalValue encodes an apd.Decimal value with its value tag, appends -// it to the supplied buffer, and returns the final buffer. -func EncodeDecimalValue(appendTo []byte, colID uint32, d *apd.Decimal) []byte { - appendTo = EncodeValueTag(appendTo, colID, Decimal) - return EncodeUntaggedDecimalValue(appendTo, d) -} +const floatValueEncodedLength = uint64AscendingEncodedLength // EncodeUntaggedDecimalValue encodes an apd.Decimal value, appends it to the supplied // buffer, and returns the final buffer. @@ -2329,75 +901,6 @@ func EncodeUntaggedDecimalValue(appendTo []byte, d *apd.Decimal) []byte { return appendTo[:varintPos+varintLen+decLen] } -// EncodeDurationValue encodes a duration.Duration value with its value tag, -// appends it to the supplied buffer, and returns the final buffer. -func EncodeDurationValue(appendTo []byte, colID uint32, d duration.Duration) []byte { - appendTo = EncodeValueTag(appendTo, colID, Duration) - return EncodeUntaggedDurationValue(appendTo, d) -} - -// EncodeUntaggedDurationValue encodes a duration.Duration value, appends it to the -// supplied buffer, and returns the final buffer. -func EncodeUntaggedDurationValue(appendTo []byte, d duration.Duration) []byte { - appendTo = EncodeNonsortingStdlibVarint(appendTo, d.Months) - appendTo = EncodeNonsortingStdlibVarint(appendTo, d.Days) - return EncodeNonsortingStdlibVarint(appendTo, d.Nanos()) -} - -// EncodeBitArrayValue encodes a bit array value with its value tag, -// appends it to the supplied buffer, and returns the final buffer. -func EncodeBitArrayValue(appendTo []byte, colID uint32, d utils.BitArray) []byte { - appendTo = EncodeValueTag(appendTo, colID, BitArray) - return EncodeUntaggedBitArrayValue(appendTo, d) -} - -// EncodeUntaggedBitArrayValue encodes a bit array value, appends it to the -// supplied buffer, and returns the final buffer. -func EncodeUntaggedBitArrayValue(appendTo []byte, d utils.BitArray) []byte { - bitLen := d.BitLen() - words, _ := d.EncodingParts() - - appendTo = EncodeNonsortingUvarint(appendTo, uint64(bitLen)) - for _, w := range words { - appendTo = EncodeUint64Ascending(appendTo, w) - } - return appendTo -} - -// EncodeUUIDValue encodes a uuid.UUID value with its value tag, appends it to -// the supplied buffer, and returns the final buffer. -func EncodeUUIDValue(appendTo []byte, colID uint32, u uuid.UUID) []byte { - appendTo = EncodeValueTag(appendTo, colID, UUID) - return EncodeUntaggedUUIDValue(appendTo, u) -} - -// EncodeUntaggedUUIDValue encodes a uuid.UUID value, appends it to the supplied buffer, -// and returns the final buffer. -func EncodeUntaggedUUIDValue(appendTo []byte, u uuid.UUID) []byte { - return append(appendTo, u.GetBytes()...) -} - -// EncodeIPAddrValue encodes a ipaddr.IPAddr value with its value tag, appends -// it to the supplied buffer, and returns the final buffer. -func EncodeIPAddrValue(appendTo []byte, colID uint32, u ipaddr.IPAddr) []byte { - appendTo = EncodeValueTag(appendTo, colID, IPAddr) - return EncodeUntaggedIPAddrValue(appendTo, u) -} - -// EncodeUntaggedIPAddrValue encodes a ipaddr.IPAddr value, appends it to the -// supplied buffer, and returns the final buffer. -func EncodeUntaggedIPAddrValue(appendTo []byte, u ipaddr.IPAddr) []byte { - return u.ToBuffer(appendTo) -} - -// EncodeJSONValue encodes an already-byte-encoded JSON value with no value tag -// but with a length prefix, appends it to the supplied buffer, and returns the -// final buffer. -func EncodeJSONValue(appendTo []byte, colID uint32, data []byte) []byte { - appendTo = EncodeValueTag(appendTo, colID, JSON) - return EncodeUntaggedBytesValue(appendTo, data) -} - // DecodeValueTag decodes a value encoded by EncodeValueTag, used as a prefix in // each of the other EncodeFooValue methods. // @@ -2445,174 +948,6 @@ func DecodeValueTag(b []byte) (typeOffset int, dataOffset int, colID uint32, typ return typeOffset, dataOffset, colID, typ, nil } -// DecodeBoolValue decodes a value encoded by EncodeBoolValue. -func DecodeBoolValue(buf []byte) (remaining []byte, b bool, err error) { - _, dataOffset, _, typ, err := DecodeValueTag(buf) - if err != nil { - return buf, false, err - } - buf = buf[dataOffset:] - switch typ { - case True: - return buf, true, nil - case False: - return buf, false, nil - default: - return buf, false, fmt.Errorf("value type is not %v or %v: %v", True, False, typ) - } -} - -// DecodeIntValue decodes a value encoded by EncodeIntValue. -func DecodeIntValue(b []byte) (remaining []byte, i int64, err error) { - b, err = decodeValueTypeAssert(b, Int) - if err != nil { - return b, 0, err - } - return DecodeUntaggedIntValue(b) -} - -// DecodeUntaggedIntValue decodes a value encoded by EncodeUntaggedIntValue. -func DecodeUntaggedIntValue(b []byte) (remaining []byte, i int64, err error) { - b, _, i, err = DecodeNonsortingStdlibVarint(b) - return b, i, err -} - -// DecodeFloatValue decodes a value encoded by EncodeFloatValue. -func DecodeFloatValue(b []byte) (remaining []byte, f float64, err error) { - b, err = decodeValueTypeAssert(b, Float) - if err != nil { - return b, 0, err - } - return DecodeUntaggedFloatValue(b) -} - -// DecodeUntaggedFloatValue decodes a value encoded by EncodeUntaggedFloatValue. -func DecodeUntaggedFloatValue(b []byte) (remaining []byte, f float64, err error) { - if len(b) < 8 { - return b, 0, fmt.Errorf("float64 value should be exactly 8 bytes: %d", len(b)) - } - var i uint64 - b, i, err = DecodeUint64Ascending(b) - return b, math.Float64frombits(i), err -} - -// DecodeBytesValue decodes a value encoded by EncodeBytesValue. -func DecodeBytesValue(b []byte) (remaining []byte, data []byte, err error) { - b, err = decodeValueTypeAssert(b, Bytes) - if err != nil { - return b, nil, err - } - return DecodeUntaggedBytesValue(b) -} - -// DecodeUntaggedBytesValue decodes a value encoded by EncodeUntaggedBytesValue. -func DecodeUntaggedBytesValue(b []byte) (remaining, data []byte, err error) { - var i uint64 - b, _, i, err = DecodeNonsortingUvarint(b) - if err != nil { - return b, nil, err - } - return b[int(i):], b[:int(i)], nil -} - -// DecodeTimeValue decodes a value encoded by EncodeTimeValue. -func DecodeTimeValue(b []byte) (remaining []byte, t time.Time, err error) { - b, err = decodeValueTypeAssert(b, Time) - if err != nil { - return b, time.Time{}, err - } - return DecodeUntaggedTimeValue(b) -} - -// DecodeUntaggedTimeValue decodes a value encoded by EncodeUntaggedTimeValue. -func DecodeUntaggedTimeValue(b []byte) (remaining []byte, t time.Time, err error) { - var sec, nsec int64 - b, _, sec, err = DecodeNonsortingStdlibVarint(b) - if err != nil { - return b, time.Time{}, err - } - b, _, nsec, err = DecodeNonsortingStdlibVarint(b) - if err != nil { - return b, time.Time{}, err - } - return b, time.Unix(sec, nsec).UTC(), nil -} - -// DecodeTimeTZValue decodes a value encoded by EncodeTimeTZValue. -func DecodeTimeTZValue(b []byte) (remaining []byte, t timetz.TimeTZ, err error) { - b, err = decodeValueTypeAssert(b, TimeTZ) - if err != nil { - return b, timetz.TimeTZ{}, err - } - return DecodeUntaggedTimeTZValue(b) -} - -// DecodeUntaggedTimeTZValue decodes a value encoded by EncodeUntaggedTimeTZValue. -func DecodeUntaggedTimeTZValue(b []byte) (remaining []byte, t timetz.TimeTZ, err error) { - var timeOfDayMicros int64 - b, _, timeOfDayMicros, err = DecodeNonsortingStdlibVarint(b) - if err != nil { - return b, timetz.TimeTZ{}, err - } - var offsetSecs int64 - b, _, offsetSecs, err = DecodeNonsortingStdlibVarint(b) - if err != nil { - return b, timetz.TimeTZ{}, err - } - // Do not use timeofday.FromInt as it truncates 24:00 into 00:00. - return b, timetz.MakeTimeTZ(timeofday.TimeOfDay(timeOfDayMicros), int32(offsetSecs)), nil -} - -// DecodeDecimalValue decodes a value encoded by EncodeDecimalValue. -func DecodeDecimalValue(b []byte) (remaining []byte, d apd.Decimal, err error) { - b, err = decodeValueTypeAssert(b, Decimal) - if err != nil { - return b, apd.Decimal{}, err - } - return DecodeUntaggedDecimalValue(b) -} - -// DecodeUntaggedBox2DValue decodes a value encoded by EncodeUntaggedBox2DValue. -func DecodeUntaggedBox2DValue( - b []byte, -) (remaining []byte, box geo.CartesianBoundingBox, err error) { - box = geo.CartesianBoundingBox{} - remaining = b - - remaining, box.LoX, err = DecodeFloatAscending(remaining) - if err != nil { - return b, box, err - } - remaining, box.HiX, err = DecodeFloatAscending(remaining) - if err != nil { - return b, box, err - } - remaining, box.LoY, err = DecodeFloatAscending(remaining) - if err != nil { - return b, box, err - } - remaining, box.HiY, err = DecodeFloatAscending(remaining) - if err != nil { - return b, box, err - } - return remaining, box, err -} - -// DecodeUntaggedGeoValue decodes a value encoded by EncodeUntaggedGeoValue into -// the provided geopb.SpatialObject reference. The so parameter must already be -// empty/reset. -func DecodeUntaggedGeoValue(b []byte, so *geopb.SpatialObject) (remaining []byte, err error) { - var data []byte - remaining, data, err = DecodeUntaggedBytesValue(b) - if err != nil { - return b, err - } - // Not using protoutil.Unmarshal since the call to so.Reset() will waste the - // pre-allocated EWKB. - err = so.Unmarshal(data) - return remaining, err -} - // DecodeUntaggedDecimalValue decodes a value encoded by EncodeUntaggedDecimalValue. func DecodeUntaggedDecimalValue(b []byte) (remaining []byte, d apd.Decimal, err error) { var i uint64 @@ -2624,351 +959,10 @@ func DecodeUntaggedDecimalValue(b []byte) (remaining []byte, d apd.Decimal, err return b[int(i):], d, err } -// DecodeIntoUntaggedDecimalValue is like DecodeUntaggedDecimalValue except it -// writes the new Decimal into the input apd.Decimal pointer, which must be -// non-nil. -func DecodeIntoUntaggedDecimalValue(d *apd.Decimal, b []byte) (remaining []byte, err error) { - var i uint64 - b, _, i, err = DecodeNonsortingStdlibUvarint(b) - if err != nil { - return b, err - } - err = DecodeIntoNonsortingDecimal(d, b[:int(i)], nil) - return b[int(i):], err -} - -// DecodeDurationValue decodes a value encoded by EncodeUntaggedDurationValue. -func DecodeDurationValue(b []byte) (remaining []byte, d duration.Duration, err error) { - b, err = decodeValueTypeAssert(b, Duration) - if err != nil { - return b, duration.Duration{}, err - } - return DecodeUntaggedDurationValue(b) -} - -// DecodeUntaggedDurationValue decodes a value encoded by EncodeUntaggedDurationValue. -func DecodeUntaggedDurationValue(b []byte) (remaining []byte, d duration.Duration, err error) { - var months, days, nanos int64 - b, _, months, err = DecodeNonsortingStdlibVarint(b) - if err != nil { - return b, duration.Duration{}, err - } - b, _, days, err = DecodeNonsortingStdlibVarint(b) - if err != nil { - return b, duration.Duration{}, err - } - b, _, nanos, err = DecodeNonsortingStdlibVarint(b) - if err != nil { - return b, duration.Duration{}, err - } - return b, duration.DecodeDuration(months, days, nanos), nil -} - -// DecodeBitArrayValue decodes a value encoded by EncodeUntaggedBitArrayValue. -func DecodeBitArrayValue(b []byte) (remaining []byte, d utils.BitArray, err error) { - b, err = decodeValueTypeAssert(b, BitArray) - if err != nil { - return b, utils.BitArray{}, err - } - return DecodeUntaggedBitArrayValue(b) -} - -// DecodeUntaggedBitArrayValue decodes a value encoded by EncodeUntaggedBitArrayValue. -func DecodeUntaggedBitArrayValue(b []byte) (remaining []byte, d utils.BitArray, err error) { - var bitLen uint64 - b, _, bitLen, err = DecodeNonsortingUvarint(b) - if err != nil { - return b, utils.BitArray{}, err - } - words, lastBitsUsed := utils.EncodingPartsForBitLen(uint(bitLen)) - for i := range words { - var val uint64 - b, val, err = DecodeUint64Ascending(b) - if err != nil { - return b, utils.BitArray{}, err - } - words[i] = val - } - ba, err := utils.FromEncodingParts(words, lastBitsUsed) - return b, ba, err -} - const uuidValueEncodedLength = 16 var _ [uuidValueEncodedLength]byte = uuid.UUID{} // Assert that uuid.UUID is length 16. -// DecodeUUIDValue decodes a value encoded by EncodeUUIDValue. -func DecodeUUIDValue(b []byte) (remaining []byte, u uuid.UUID, err error) { - b, err = decodeValueTypeAssert(b, UUID) - if err != nil { - return b, u, err - } - return DecodeUntaggedUUIDValue(b) -} - -// DecodeUntaggedUUIDValue decodes a value encoded by EncodeUntaggedUUIDValue. -func DecodeUntaggedUUIDValue(b []byte) (remaining []byte, u uuid.UUID, err error) { - u, err = uuid.FromBytes(b[:uuidValueEncodedLength]) - if err != nil { - return b, uuid.UUID{}, err - } - return b[uuidValueEncodedLength:], u, nil -} - -// DecodeIPAddrValue decodes a value encoded by EncodeIPAddrValue. -func DecodeIPAddrValue(b []byte) (remaining []byte, u ipaddr.IPAddr, err error) { - b, err = decodeValueTypeAssert(b, IPAddr) - if err != nil { - return b, u, err - } - return DecodeUntaggedIPAddrValue(b) -} - -// DecodeUntaggedIPAddrValue decodes a value encoded by EncodeUntaggedIPAddrValue. -func DecodeUntaggedIPAddrValue(b []byte) (remaining []byte, u ipaddr.IPAddr, err error) { - remaining, err = u.FromBuffer(b) - return remaining, u, err -} - -func decodeValueTypeAssert(b []byte, expected Type) ([]byte, error) { - _, dataOffset, _, typ, err := DecodeValueTag(b) - if err != nil { - return b, err - } - b = b[dataOffset:] - if typ != expected { - return b, errors.Errorf("value type is not %s: %s", expected, typ) - } - return b, nil -} - -// PeekValueLength returns the length of the encoded value at the start of b. -// Note: If this function succeeds, it's not a guarantee that decoding the value -// will succeed. -// -// `b` can point either at beginning of the "full tag" with the column id, or it -// can point to the beginning of the type part of the tag, as indicated by the -// `typeOffset` returned by this or DecodeValueTag. -// -// The length returned is the full length of the encoded value, including the -// entire tag. -func PeekValueLength(b []byte) (typeOffset int, length int, err error) { - if len(b) == 0 { - return 0, 0, nil - } - var dataOffset int - var typ Type - typeOffset, dataOffset, _, typ, err = DecodeValueTag(b) - if err != nil { - return 0, 0, err - } - length, err = PeekValueLengthWithOffsetsAndType(b, dataOffset, typ) - return typeOffset, length, err -} - -// PeekValueLengthWithOffsetsAndType is the same as PeekValueLength, except it -// expects a dataOffset and typ value from a previous call to DecodeValueTag -// on its input byte slice. Use this if you've already called DecodeValueTag -// on the input for another reason, to avoid it getting called twice. -func PeekValueLengthWithOffsetsAndType(b []byte, dataOffset int, typ Type) (length int, err error) { - b = b[dataOffset:] - switch typ { - case Null: - return dataOffset, nil - case True, False: - return dataOffset, nil - case Int: - _, n, _, err := DecodeNonsortingStdlibVarint(b) - return dataOffset + n, err - case Float: - return dataOffset + floatValueEncodedLength, nil - case Bytes, Array, JSON, Geo: - _, n, i, err := DecodeNonsortingUvarint(b) - return dataOffset + n + int(i), err - case Box2D: - length, err := peekBox2DLength(b) - if err != nil { - return 0, err - } - return dataOffset + length, nil - case BitArray: - _, n, bitLen, err := DecodeNonsortingUvarint(b) - if err != nil { - return 0, err - } - numWords, _ := utils.SizesForBitLen(uint(bitLen)) - return dataOffset + n + int(numWords)*8, err - case Tuple: - rem, l, numTuples, err := DecodeNonsortingUvarint(b) - if err != nil { - return 0, errors.Wrapf(err, "cannot decode tuple header: ") - } - for i := 0; i < int(numTuples); i++ { - _, entryLen, err := PeekValueLength(rem) - if err != nil { - return 0, errors.Wrapf(err, "cannot peek tuple entry %d", i) - } - l += entryLen - rem = rem[entryLen:] - } - return dataOffset + l, nil - case Decimal: - _, n, i, err := DecodeNonsortingStdlibUvarint(b) - return dataOffset + n + int(i), err - case Time, TimeTZ: - n, err := getMultiNonsortingVarintLen(b, 2) - return dataOffset + n, err - case Duration: - n, err := getMultiNonsortingVarintLen(b, 3) - return dataOffset + n, err - case UUID: - return dataOffset + uuidValueEncodedLength, err - case IPAddr: - family := ipaddr.IPFamily(b[0]) - if family == ipaddr.IPv4family { - return dataOffset + ipaddr.IPv4size, err - } else if family == ipaddr.IPv6family { - return dataOffset + ipaddr.IPv6size, err - } - return 0, errors.Errorf("got invalid INET IP family: %d", family) - default: - return 0, errors.Errorf("unknown type %s", typ) - } -} - -// PrintableBytes returns true iff the given byte array is a valid -// UTF-8 sequence and it is printable. -func PrintableBytes(b []byte) bool { - return len(bytes.TrimLeftFunc(b, isValidAndPrintableRune)) == 0 -} - -func isValidAndPrintableRune(r rune) bool { - return r != utf8.RuneError && unicode.IsPrint(r) -} - -// PrettyPrintValueEncoded returns a string representation of the first -// decodable value in the provided byte slice, along with the remaining byte -// slice after decoding. -func PrettyPrintValueEncoded(b []byte) ([]byte, string, error) { - _, dataOffset, _, typ, err := DecodeValueTag(b) - if err != nil { - return b, "", err - } - switch typ { - case Null: - b = b[dataOffset:] - return b, "NULL", nil - case True: - b = b[dataOffset:] - return b, "true", nil - case False: - b = b[dataOffset:] - return b, "false", nil - case Int: - var i int64 - b, i, err = DecodeIntValue(b) - if err != nil { - return b, "", err - } - return b, strconv.FormatInt(i, 10), nil - case Float: - var f float64 - b, f, err = DecodeFloatValue(b) - if err != nil { - return b, "", err - } - return b, strconv.FormatFloat(f, 'g', -1, 64), nil - case Decimal: - var d apd.Decimal - b, d, err = DecodeDecimalValue(b) - if err != nil { - return b, "", err - } - return b, d.String(), nil - case Bytes: - var data []byte - b, data, err = DecodeBytesValue(b) - if err != nil { - return b, "", err - } - if PrintableBytes(data) { - return b, string(data), nil - } - // The following code extends hex.EncodeToString(). - dst := make([]byte, 2+hex.EncodedLen(len(data))) - dst[0], dst[1] = '0', 'x' - hex.Encode(dst[2:], data) - return b, string(dst), nil - case Time: - var t time.Time - b, t, err = DecodeTimeValue(b) - if err != nil { - return b, "", err - } - return b, t.UTC().Format(time.RFC3339Nano), nil - case TimeTZ: - var t timetz.TimeTZ - b, t, err = DecodeTimeTZValue(b) - if err != nil { - return b, "", err - } - return b, t.String(), nil - case Duration: - var d duration.Duration - b, d, err = DecodeDurationValue(b) - if err != nil { - return b, "", err - } - return b, d.StringNanos(), nil - case BitArray: - var d utils.BitArray - b, d, err = DecodeBitArrayValue(b) - if err != nil { - return b, "", err - } - return b, "B" + d.String(), nil - case UUID: - var u uuid.UUID - b, u, err = DecodeUUIDValue(b) - if err != nil { - return b, "", err - } - return b, u.String(), nil - case IPAddr: - var ipAddr ipaddr.IPAddr - b, ipAddr, err = DecodeIPAddrValue(b) - if err != nil { - return b, "", err - } - return b, ipAddr.String(), nil - default: - return b, "", errors.Errorf("unknown type %s", typ) - } -} - -// DecomposeKeyTokens breaks apart a key into its individual key-encoded values -// and returns a slice of byte slices, one for each key-encoded value. -// It also returns whether the key contains a NULL value. -func DecomposeKeyTokens(b []byte) (tokens [][]byte, containsNull bool, err error) { - var out [][]byte - - for len(b) > 0 { - tokenLen, err := PeekLength(b) - if err != nil { - return nil, false, err - } - - if PeekType(b) == Null { - containsNull = true - } - - out = append(out, b[:tokenLen]) - b = b[tokenLen:] - } - - return out, containsNull, nil -} - // getInvertedIndexKeyLength finds the length of an inverted index key // encoded as a byte array. func getInvertedIndexKeyLength(b []byte) (int, error) { @@ -3008,69 +1002,6 @@ func getJSONInvertedIndexKeyLength(buf []byte) (int, error) { } } -// EncodeArrayKeyMarker adds the array key encoding marker to buf and -// returns the new buffer. -func EncodeArrayKeyMarker(buf []byte, dir Direction) []byte { - switch dir { - case Ascending: - return append(buf, arrayKeyMarker) - case Descending: - return append(buf, arrayKeyDescendingMarker) - default: - panic("invalid direction") - } -} - -// EncodeArrayKeyTerminator adds the array key terminator to buf and -// returns the new buffer. -func EncodeArrayKeyTerminator(buf []byte, dir Direction) []byte { - switch dir { - case Ascending: - return append(buf, arrayKeyTerminator) - case Descending: - return append(buf, arrayKeyDescendingTerminator) - default: - panic("invalid direction") - } -} - -// EncodeNullWithinArrayKey encodes NULL within a key encoded array. -func EncodeNullWithinArrayKey(buf []byte, dir Direction) []byte { - switch dir { - case Ascending: - return append(buf, ascendingNullWithinArrayKey) - case Descending: - return append(buf, descendingNullWithinArrayKey) - default: - panic("invalid direction") - } -} - -// IsNextByteArrayEncodedNull returns if the first byte in the input -// is the NULL encoded byte within an array key. -func IsNextByteArrayEncodedNull(buf []byte, dir Direction) bool { - expected := ascendingNullWithinArrayKey - if dir == Descending { - expected = descendingNullWithinArrayKey - } - return buf[0] == expected -} - -// ValidateAndConsumeArrayKeyMarker checks that the marker at the front -// of buf is valid for an array of the given direction, and consumes it -// if so. It returns an error if the tag is invalid. -func ValidateAndConsumeArrayKeyMarker(buf []byte, dir Direction) ([]byte, error) { - typ := PeekType(buf) - expected := ArrayKeyAsc - if dir == Descending { - expected = ArrayKeyDesc - } - if typ != expected { - return nil, errors.Newf("invalid type found %s", typ) - } - return buf[1:], nil -} - // IsArrayKeyDone returns if the first byte in the input is the array // terminator for the input direction. func IsArrayKeyDone(buf []byte, dir Direction) bool { diff --git a/postgres/parser/geo/geodist/geodist.go b/postgres/parser/geo/geodist/geodist.go deleted file mode 100644 index 49eb8b7828..0000000000 --- a/postgres/parser/geo/geodist/geodist.go +++ /dev/null @@ -1,444 +0,0 @@ -// Copyright 2023 Dolthub, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Copyright 2020 The Cockroach Authors. -// -// Use of this software is governed by the Business Source License -// included in the file licenses/BSL.txt. -// -// As of the Change Date specified in that file, in accordance with -// the Business Source License, use of this software will be governed -// by the Apache License, Version 2.0, included in the file -// licenses/APL.txt. - -// Package geodist finds distances between two geospatial shapes. -package geodist - -import ( - "github.com/cockroachdb/errors" - "github.com/golang/geo/s2" - "github.com/twpayne/go-geom" -) - -// Point is a union of the point types used in geometry and geography representation. -// The interfaces for distance calculation defined below are shared for both representations, -// and this union helps us avoid heap allocations by doing cheap copy-by-value of points. The -// code that peers inside a Point knows which of the two fields is populated. -type Point struct { - GeomPoint geom.Coord - GeogPoint s2.Point -} - -// IsShape implements the geodist.Shape interface. -func (p *Point) IsShape() {} - -// Edge is a struct that represents a connection between two points. -type Edge struct { - V0, V1 Point -} - -// LineString is an interface that represents a geospatial LineString. -type LineString interface { - Edge(i int) Edge - NumEdges() int - Vertex(i int) Point - NumVertexes() int - IsShape() - IsLineString() -} - -// LinearRing is an interface that represents a geospatial LinearRing. -type LinearRing interface { - Edge(i int) Edge - NumEdges() int - Vertex(i int) Point - NumVertexes() int - IsShape() - IsLinearRing() -} - -// shapeWithEdges represents any shape that contains edges. -type shapeWithEdges interface { - Edge(i int) Edge - NumEdges() int -} - -// Polygon is an interface that represents a geospatial Polygon. -type Polygon interface { - LinearRing(i int) LinearRing - NumLinearRings() int - IsShape() - IsPolygon() -} - -// Shape is an interface that represents any Geospatial shape. -type Shape interface { - IsShape() -} - -var _ Shape = (*Point)(nil) -var _ Shape = (LineString)(nil) -var _ Shape = (LinearRing)(nil) -var _ Shape = (Polygon)(nil) - -// DistanceUpdater is a provided hook that has a series of functions that allows -// the caller to maintain the distance value desired. -type DistanceUpdater interface { - // Update updates the distance based on two provided points, - // returning if the function should return early. - Update(a Point, b Point) bool - // OnIntersects is called when two shapes intersects. - OnIntersects(p Point) bool - // Distance returns the distance to return so far. - Distance() float64 - // IsMaxDistance returns whether the updater is looking for maximum distance. - IsMaxDistance() bool - // FlipGeometries is called to flip the order of geometries. - FlipGeometries() -} - -// EdgeCrosser is a provided hook that calculates whether edges intersect. -type EdgeCrosser interface { - // ChainCrossing assumes there is an edge to compare against, and the previous - // point `p0` is the start of the next edge. It will then returns whether (p0, p) - // intersects with the edge and point of intersection if they intersect. - // When complete, point p will become p0. - // Desired usage examples: - // crosser := NewEdgeCrosser(edge.V0, edge.V1, startingP0) - // intersects, _ := crosser.ChainCrossing(p1) - // laterIntersects, _ := crosser.ChainCrossing(p2) - // intersects |= laterIntersects .... - ChainCrossing(p Point) (bool, Point) -} - -// DistanceCalculator contains calculations which allow ShapeDistance to calculate -// the distance between two shapes. -type DistanceCalculator interface { - // DistanceUpdater returns the DistanceUpdater for the current set of calculations. - DistanceUpdater() DistanceUpdater - // NewEdgeCrosser returns a new EdgeCrosser with the given edge initialized to be - // the edge to compare against, and the start point to be the start of the first - // edge to compare against. - NewEdgeCrosser(edge Edge, startPoint Point) EdgeCrosser - // PointInLinearRing returns whether the point is inside the given linearRing. - PointInLinearRing(point Point, linearRing LinearRing) bool - // ClosestPointToEdge returns the closest point to the infinite line denoted by - // the edge, and a bool on whether this point lies on the edge segment. - ClosestPointToEdge(edge Edge, point Point) (Point, bool) - // BoundingBoxIntersects returns whether the bounding boxes of the shapes in - // question intersect. - BoundingBoxIntersects() bool -} - -// ShapeDistance returns the distance between two given shapes. -// Distance is defined by the DistanceUpdater provided by the interface. -// It returns whether the function above should return early. -func ShapeDistance(c DistanceCalculator, a Shape, b Shape) (bool, error) { - switch a := a.(type) { - case *Point: - switch b := b.(type) { - case *Point: - return c.DistanceUpdater().Update(*a, *b), nil - case LineString: - return onPointToLineString(c, *a, b), nil - case Polygon: - return onPointToPolygon(c, *a, b), nil - default: - return false, errors.Newf("unknown shape: %T", b) - } - case LineString: - switch b := b.(type) { - case *Point: - c.DistanceUpdater().FlipGeometries() - // defer to restore the order of geometries at the end of the function call. - defer c.DistanceUpdater().FlipGeometries() - return onPointToLineString(c, *b, a), nil - case LineString: - return onShapeEdgesToShapeEdges(c, a, b), nil - case Polygon: - return onLineStringToPolygon(c, a, b), nil - default: - return false, errors.Newf("unknown shape: %T", b) - } - case Polygon: - switch b := b.(type) { - case *Point: - c.DistanceUpdater().FlipGeometries() - // defer to restore the order of geometries at the end of the function call. - defer c.DistanceUpdater().FlipGeometries() - return onPointToPolygon(c, *b, a), nil - case LineString: - c.DistanceUpdater().FlipGeometries() - // defer to restore the order of geometries at the end of the function call. - defer c.DistanceUpdater().FlipGeometries() - return onLineStringToPolygon(c, b, a), nil - case Polygon: - return onPolygonToPolygon(c, a, b), nil - default: - return false, errors.Newf("unknown shape: %T", b) - } - } - return false, errors.Newf("unknown shape: %T", a) -} - -// onPointToEdgesExceptFirstEdgeStart updates the distance against the edges of a shape and a point. -// It will only check the V1 of each edge and assumes the first edge start does not need the distance -// to be computed. -func onPointToEdgesExceptFirstEdgeStart(c DistanceCalculator, a Point, b shapeWithEdges) bool { - for edgeIdx := 0; edgeIdx < b.NumEdges(); edgeIdx++ { - edge := b.Edge(edgeIdx) - // Check against all V1 of every edge. - if c.DistanceUpdater().Update(a, edge.V1) { - return true - } - // The max distance between a point and the set of points representing an edge is the - // maximum distance from the point and the pair of end-points of the edge, so we don't - // need to update the distance using the projected point. - if !c.DistanceUpdater().IsMaxDistance() { - // Also project the point to the infinite line of the edge, and compare if the closestPoint - // lies on the edge. - if closestPoint, ok := c.ClosestPointToEdge(edge, a); ok { - if c.DistanceUpdater().Update(a, closestPoint) { - return true - } - } - } - } - return false -} - -// onPointToLineString updates the distance between a point and a polyline. -// Returns true if the calling function should early exit. -func onPointToLineString(c DistanceCalculator, a Point, b LineString) bool { - // Compare the first point, to avoid checking each V0 in the chain afterwards. - if c.DistanceUpdater().Update(a, b.Vertex(0)) { - return true - } - return onPointToEdgesExceptFirstEdgeStart(c, a, b) -} - -// onPointToPolygon updates the distance between a point and a polygon. -// Returns true if the calling function should early exit. -func onPointToPolygon(c DistanceCalculator, a Point, b Polygon) bool { - // MaxDistance: When computing the maximum distance, the cases are: - // - The point P is not contained in the exterior of the polygon G. - // Say vertex V is the vertex of the exterior of the polygon that is - // furthest away from point P (among all the exterior vertices). - // - One can prove that any vertex of the holes will be closer to point P than vertex V. - // Similarly we can prove that any point in the interior of the polygin is closer to P than vertex V. - // Therefore we only need to compare with the exterior. - // - The point P is contained in the exterior and inside a hole of polygon G. - // One can again prove that the furthest point in the polygon from P is one of the vertices of the exterior. - // - The point P is contained in the polygon. One can again prove the same property. - // So we only need to compare with the exterior ring. - // MinDistance: If the exterior ring does not contain the point, we just need to calculate the distance to - // the exterior ring. - // BoundingBoxIntersects: if the bounding box of the shape being calculated does not intersect, - // then we only need to compare the outer loop. - if c.DistanceUpdater().IsMaxDistance() || !c.BoundingBoxIntersects() || !c.PointInLinearRing(a, b.LinearRing(0)) { - return onPointToEdgesExceptFirstEdgeStart(c, a, b.LinearRing(0)) - } - // At this point it may be inside a hole. - // If it is in a hole, return the distance to the hole. - for ringIdx := 1; ringIdx < b.NumLinearRings(); ringIdx++ { - ring := b.LinearRing(ringIdx) - if c.PointInLinearRing(a, ring) { - return onPointToEdgesExceptFirstEdgeStart(c, a, ring) - } - } - - // Otherwise, we are inside the polygon. - return c.DistanceUpdater().OnIntersects(a) -} - -// onShapeEdgesToShapeEdges updates the distance between two shapes by -// only looking at the edges. -// Returns true if the calling function should early exit. -func onShapeEdgesToShapeEdges(c DistanceCalculator, a shapeWithEdges, b shapeWithEdges) bool { - for aEdgeIdx := 0; aEdgeIdx < a.NumEdges(); aEdgeIdx++ { - aEdge := a.Edge(aEdgeIdx) - var crosser EdgeCrosser - // MaxDistance: the max distance between 2 edges is the maximum of the distance across - // pairs of vertices chosen from each edge. - // It does not matter whether the edges cross, so we skip this check. - // BoundingBoxIntersects: if the bounding box of the two shapes do not intersect, - // then we don't need to check whether edges intersect either. - if !c.DistanceUpdater().IsMaxDistance() && c.BoundingBoxIntersects() { - crosser = c.NewEdgeCrosser(aEdge, b.Edge(0).V0) - } - for bEdgeIdx := 0; bEdgeIdx < b.NumEdges(); bEdgeIdx++ { - bEdge := b.Edge(bEdgeIdx) - if crosser != nil { - // If the edges cross, the distance is 0. - intersects, intersectionPoint := crosser.ChainCrossing(bEdge.V1) - if intersects { - return c.DistanceUpdater().OnIntersects(intersectionPoint) - } - } - - // Check the vertex against the ends of the edges. - if c.DistanceUpdater().Update(aEdge.V0, bEdge.V0) || - c.DistanceUpdater().Update(aEdge.V0, bEdge.V1) || - c.DistanceUpdater().Update(aEdge.V1, bEdge.V0) || - c.DistanceUpdater().Update(aEdge.V1, bEdge.V1) { - return true - } - // Only project vertexes to edges if we are looking at the edges. - if !c.DistanceUpdater().IsMaxDistance() { - if projectVertexToEdge(c, aEdge.V0, bEdge) || - projectVertexToEdge(c, aEdge.V1, bEdge) { - return true - } - c.DistanceUpdater().FlipGeometries() - if projectVertexToEdge(c, bEdge.V0, aEdge) || - projectVertexToEdge(c, bEdge.V1, aEdge) { - // Restore the order of geometries. - c.DistanceUpdater().FlipGeometries() - return true - } - // Restore the order of geometries. - c.DistanceUpdater().FlipGeometries() - } - } - } - return false -} - -// projectVertexToEdge attempts to project the point onto the given edge. -// Returns true if the calling function should early exit. -func projectVertexToEdge(c DistanceCalculator, vertex Point, edge Edge) bool { - // Also check the projection of the vertex onto the edge. - if closestPoint, ok := c.ClosestPointToEdge(edge, vertex); ok { - if c.DistanceUpdater().Update(vertex, closestPoint) { - return true - } - } - return false -} - -// onLineStringToPolygon updates the distance between a polyline and a polygon. -// Returns true if the calling function should early exit. -func onLineStringToPolygon(c DistanceCalculator, a LineString, b Polygon) bool { - // MinDistance: If we know at least one point is outside the exterior ring, then there are two cases: - // * the line is always outside the exterior ring. We only need to compare the line - // against the exterior ring. - // * the line intersects with the exterior ring. - // In both these cases, we can defer to the edge to edge comparison between the line - // and the exterior ring. - // We use the first point of the linestring for this check. - // MaxDistance: the furthest distance from a LineString to a Polygon is always against the - // exterior ring. This follows the reasoning under "onPointToPolygon", but we must now - // check each point in the LineString. - // BoundingBoxIntersects: if the bounding box of the two shapes do not intersect, - // then the distance is always from the LineString to the exterior ring. - if c.DistanceUpdater().IsMaxDistance() || - !c.BoundingBoxIntersects() || - !c.PointInLinearRing(a.Vertex(0), b.LinearRing(0)) { - return onShapeEdgesToShapeEdges(c, a, b.LinearRing(0)) - } - - // Now we are guaranteed that there is at least one point inside the exterior ring. - // - // For a polygon with no holes, the fact that there is a point inside the exterior - // ring would imply that the distance is zero. - // - // However, when there are holes, it is possible that the distance is non-zero if - // polyline A is completely contained inside a hole. We iterate over the holes and - // compute the distance between the hole and polyline A. - // * If polyline A is within the given distance, we can immediately return. - // * If polyline A does not intersect the hole but there is at least one point inside - // the hole, must be inside that hole and so the distance of this polyline to this hole - // is the distance of this polyline to this polygon. - for ringIdx := 1; ringIdx < b.NumLinearRings(); ringIdx++ { - hole := b.LinearRing(ringIdx) - if onShapeEdgesToShapeEdges(c, a, hole) { - return true - } - for pointIdx := 0; pointIdx < a.NumVertexes(); pointIdx++ { - if c.PointInLinearRing(a.Vertex(pointIdx), hole) { - return false - } - } - } - - // This means we are inside the exterior ring, and no points are inside a hole. - // This means the point is inside the polygon. - return c.DistanceUpdater().OnIntersects(a.Vertex(0)) -} - -// onPolygonToPolygon updates the distance between two polygons. -// Returns true if the calling function should early exit. -func onPolygonToPolygon(c DistanceCalculator, a Polygon, b Polygon) bool { - aFirstPoint := a.LinearRing(0).Vertex(0) - bFirstPoint := b.LinearRing(0).Vertex(0) - - // MinDistance: - // If there is at least one point on the the exterior ring of B that is outside the exterior ring - // of A, then we have one of these two cases: - // * The exterior rings of A and B intersect. The distance can always be found by comparing - // the exterior rings. - // * The exterior rings of A and B never meet. This distance can always be found - // by only comparing the exterior rings. - // If we find the point is inside the exterior ring, A could contain B, so this reasoning - // does not apply. - // - // The same reasoning applies if there is at least one point on the exterior ring of A - // that is outside the exterior ring of B. - // - // As such, we only need to compare the exterior rings if we detect this. - // - // MaxDistance: - // The furthest distance between two polygons is always against the exterior rings of each other. - // This closely follows the reasoning pointed out in "onPointToPolygon". Holes are always located - // inside the exterior ring of a polygon, so the exterior ring will always contain a point - // with a larger max distance. - // BoundingBoxIntersects: if the bounding box of the two shapes do not intersect, - // then the distance is always between the two exterior rings. - if c.DistanceUpdater().IsMaxDistance() || - !c.BoundingBoxIntersects() || - !c.PointInLinearRing(bFirstPoint, a.LinearRing(0)) && !c.PointInLinearRing(aFirstPoint, b.LinearRing(0)) { - return onShapeEdgesToShapeEdges(c, a.LinearRing(0), b.LinearRing(0)) - } - - // If any point of polygon A is inside a hole of polygon B, then either: - // * A is inside the hole and the closest point can be found by comparing A's outer - // linearRing and the hole in B, or - // * A intersects this hole and the distance is zero, which can also be found by comparing - // A's outer linearRing and the hole in B. - // In this case, we only need to compare the holes of B to contain a single point A. - for ringIdx := 1; ringIdx < b.NumLinearRings(); ringIdx++ { - bHole := b.LinearRing(ringIdx) - if c.PointInLinearRing(aFirstPoint, bHole) { - return onShapeEdgesToShapeEdges(c, a.LinearRing(0), bHole) - } - } - - // Do the same check for the polygons the other way around. - c.DistanceUpdater().FlipGeometries() - // defer to restore the order of geometries at the end of the function call. - defer c.DistanceUpdater().FlipGeometries() - for ringIdx := 1; ringIdx < a.NumLinearRings(); ringIdx++ { - aHole := a.LinearRing(ringIdx) - if c.PointInLinearRing(bFirstPoint, aHole) { - return onShapeEdgesToShapeEdges(c, b.LinearRing(0), aHole) - } - } - - // Now we know either a point of the exterior ring A is definitely inside polygon B - // or vice versa. This is an intersection. - if c.PointInLinearRing(aFirstPoint, b.LinearRing(0)) { - return c.DistanceUpdater().OnIntersects(aFirstPoint) - } - return c.DistanceUpdater().OnIntersects(bFirstPoint) -} diff --git a/postgres/parser/geo/geogen/geogen.go b/postgres/parser/geo/geogen/geogen.go deleted file mode 100644 index 902c0a820c..0000000000 --- a/postgres/parser/geo/geogen/geogen.go +++ /dev/null @@ -1,227 +0,0 @@ -// Copyright 2023 Dolthub, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Copyright 2020 The Cockroach Authors. -// -// Use of this software is governed by the Business Source License -// included in the file licenses/BSL.txt. -// -// As of the Change Date specified in that file, in accordance with -// the Business Source License, use of this software will be governed -// by the Apache License, Version 2.0, included in the file -// licenses/APL.txt. - -// Package geogen provides utilities for generating various geospatial types. -package geogen - -import ( - "math" - "math/rand" - "sort" - - "github.com/cockroachdb/errors" - "github.com/twpayne/go-geom" - - "github.com/dolthub/doltgresql/postgres/parser/geo" - "github.com/dolthub/doltgresql/postgres/parser/geo/geopb" - "github.com/dolthub/doltgresql/postgres/parser/geo/geoprojbase" -) - -var validShapeTypes = []geopb.ShapeType{ - geopb.ShapeType_Point, - geopb.ShapeType_LineString, - geopb.ShapeType_Polygon, - geopb.ShapeType_MultiPoint, - geopb.ShapeType_MultiLineString, - geopb.ShapeType_MultiPolygon, - geopb.ShapeType_GeometryCollection, -} - -// RandomCoord generates a random coord in the given bounds. -func RandomCoord(rng *rand.Rand, min float64, max float64) float64 { - return rng.Float64()*(max-min) + min -} - -// RandomValidLinearRingCoords generates a flat float64 array of coordinates that represents -// a completely closed shape that can represent a simple LinearRing. This shape is always valid. -// A LinearRing must have at least 3 points. A point is added at the end to close the ring. -// Implements the algorithm in https://observablehq.com/@tarte0/generate-random-simple-polygon. -func RandomValidLinearRingCoords( - rng *rand.Rand, numPoints int, minX float64, maxX float64, minY float64, maxY float64, -) []geom.Coord { - if numPoints < 3 { - panic(errors.Newf("need at least 3 points, got %d", numPoints)) - } - // Generate N random points, and find the center. - coords := make([]geom.Coord, numPoints+1) - var centerX, centerY float64 - for i := 0; i < numPoints; i++ { - coords[i] = geom.Coord{ - RandomCoord(rng, minX, maxX), - RandomCoord(rng, minY, maxY), - } - centerX += coords[i].X() - centerY += coords[i].Y() - } - - centerX /= float64(numPoints) - centerY /= float64(numPoints) - - // Sort by the angle of all the points relative to the center. - // Use ascending order of angle to get a CCW loop. - sort.Slice(coords[:numPoints], func(i, j int) bool { - angleI := math.Atan2(coords[i].Y()-centerY, coords[i].X()-centerX) - angleJ := math.Atan2(coords[j].Y()-centerY, coords[j].X()-centerX) - return angleI < angleJ - }) - - // Append the first coordinate to the end. - coords[numPoints] = coords[0] - return coords -} - -// RandomPoint generates a random Point. -func RandomPoint( - rng *rand.Rand, minX float64, maxX float64, minY float64, maxY float64, srid geopb.SRID, -) *geom.Point { - return geom.NewPointFlat(geom.XY, []float64{ - RandomCoord(rng, minX, maxX), - RandomCoord(rng, minY, maxY), - }).SetSRID(int(srid)) -} - -// RandomLineString generates a random LineString. -func RandomLineString( - rng *rand.Rand, minX float64, maxX float64, minY float64, maxY float64, srid geopb.SRID, -) *geom.LineString { - numCoords := 3 + rand.Intn(10) - randCoords := RandomValidLinearRingCoords(rng, numCoords, minX, maxX, minY, maxY) - - // Extract a random substring from the LineString by truncating at the ends. - var minTrunc, maxTrunc int - // Ensure we always have at least two points. - for maxTrunc-minTrunc < 2 { - minTrunc, maxTrunc = rand.Intn(numCoords+1), rand.Intn(numCoords+1) - // Ensure maxTrunc >= minTrunc. - if minTrunc > maxTrunc { - minTrunc, maxTrunc = maxTrunc, minTrunc - } - } - return geom.NewLineString(geom.XY).MustSetCoords(randCoords[minTrunc:maxTrunc]).SetSRID(int(srid)) -} - -// RandomPolygon generates a random Polygon. -func RandomPolygon( - rng *rand.Rand, minX float64, maxX float64, minY float64, maxY float64, srid geopb.SRID, -) *geom.Polygon { - // TODO(otan): generate random holes inside the Polygon. - // Ideas: - // * We can do something like use 4 arbitrary points in the LinearRing to generate a BoundingBox, - // and re-use "PointInLinearRing" to generate N random points inside the 4 points to form - // a "sub" linear ring inside. - // * Generate a random set of polygons, see which ones they fully cover and use that. - return geom.NewPolygon(geom.XY).MustSetCoords([][]geom.Coord{ - RandomValidLinearRingCoords(rng, 3+rng.Intn(10), minX, maxX, minY, maxY), - }).SetSRID(int(srid)) -} - -// RandomGeomT generates a random geom.T object within the given bounds and SRID. -func RandomGeomT( - rng *rand.Rand, minX float64, maxX float64, minY float64, maxY float64, srid geopb.SRID, -) geom.T { - shapeType := validShapeTypes[rng.Intn(len(validShapeTypes))] - switch shapeType { - case geopb.ShapeType_Point: - return RandomPoint(rng, minX, maxX, minY, maxY, srid) - case geopb.ShapeType_LineString: - return RandomLineString(rng, minX, maxX, minY, maxY, srid) - case geopb.ShapeType_Polygon: - return RandomPolygon(rng, minX, maxX, minY, maxY, srid) - case geopb.ShapeType_MultiPoint: - // TODO(otan): add empty points. - ret := geom.NewMultiPoint(geom.XY).SetSRID(int(srid)) - num := 1 + rng.Intn(10) - for i := 0; i < num; i++ { - if err := ret.Push(RandomPoint(rng, minX, maxX, minY, maxY, srid)); err != nil { - panic(err) - } - } - return ret - case geopb.ShapeType_MultiLineString: - // TODO(otan): add empty LineStrings. - ret := geom.NewMultiLineString(geom.XY).SetSRID(int(srid)) - num := 1 + rng.Intn(10) - for i := 0; i < num; i++ { - if err := ret.Push(RandomLineString(rng, minX, maxX, minY, maxY, srid)); err != nil { - panic(err) - } - } - return ret - case geopb.ShapeType_MultiPolygon: - // TODO(otan): add empty Polygons. - ret := geom.NewMultiPolygon(geom.XY).SetSRID(int(srid)) - num := 1 + rng.Intn(10) - for i := 0; i < num; i++ { - if err := ret.Push(RandomPolygon(rng, minX, maxX, minY, maxY, srid)); err != nil { - panic(err) - } - } - return ret - case geopb.ShapeType_GeometryCollection: - ret := geom.NewGeometryCollection().SetSRID(int(srid)) - num := 1 + rng.Intn(10) - for i := 0; i < num; i++ { - var shape geom.T - needShape := true - // Keep searching for a non GeometryCollection. - for needShape { - shape = RandomGeomT(rng, minX, maxX, minY, maxY, srid) - _, needShape = shape.(*geom.GeometryCollection) - } - if err := ret.Push(shape); err != nil { - panic(err) - } - } - return ret - } - panic(errors.Newf("unknown shape type: %v", shapeType)) -} - -// RandomGeometry generates a random Geometry with the given SRID. -func RandomGeometry(rng *rand.Rand, srid geopb.SRID) geo.Geometry { - minX, maxX := -math.MaxFloat32, math.MaxFloat32 - minY, maxY := -math.MaxFloat32, math.MaxFloat32 - proj, ok := geoprojbase.Projections[srid] - if ok { - minX, maxX = proj.Bounds.MinX, proj.Bounds.MaxX - minY, maxY = proj.Bounds.MinY, proj.Bounds.MaxY - } - ret, err := geo.MakeGeometryFromGeomT(RandomGeomT(rng, minX, maxX, minY, maxY, srid)) - if err != nil { - panic(err) - } - return ret -} - -// RandomGeography generates a random Geometry with the given SRID. -func RandomGeography(rng *rand.Rand, srid geopb.SRID) geo.Geography { - // TODO(otan): generate geographies that traverse latitude/longitude boundaries. - minX, maxX := -180.0, 180.0 - minY, maxY := -90.0, 90.0 - ret, err := geo.MakeGeographyFromGeomT(RandomGeomT(rng, minX, maxX, minY, maxY, srid)) - if err != nil { - panic(err) - } - return ret -} diff --git a/postgres/parser/geo/geogfn/azimuth.go b/postgres/parser/geo/geogfn/azimuth.go deleted file mode 100644 index b631bb86fe..0000000000 --- a/postgres/parser/geo/geogfn/azimuth.go +++ /dev/null @@ -1,87 +0,0 @@ -// Copyright 2023 Dolthub, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Copyright 2020 The Cockroach Authors. -// -// Use of this software is governed by the Business Source License -// included in the file licenses/BSL.txt. -// -// As of the Change Date specified in that file, in accordance with -// the Business Source License, use of this software will be governed -// by the Apache License, Version 2.0, included in the file -// licenses/APL.txt. - -package geogfn - -import ( - "math" - - "github.com/cockroachdb/errors" - "github.com/golang/geo/s2" - "github.com/twpayne/go-geom" - - "github.com/dolthub/doltgresql/postgres/parser/geo" -) - -// Azimuth returns the azimuth in radians of the segment defined by the given point geometries. -// The azimuth is angle is referenced from north, and is positive clockwise. -// North = 0; East = Ï€/2; South = Ï€; West = 3Ï€/2. -// Returns nil if the two points are the same. -// Returns an error if any of the two Geography items are not points. -func Azimuth(a geo.Geography, b geo.Geography) (*float64, error) { - if a.SRID() != b.SRID() { - return nil, geo.NewMismatchingSRIDsError(a.SpatialObject(), b.SpatialObject()) - } - - aGeomT, err := a.AsGeomT() - if err != nil { - return nil, err - } - - aPoint, ok := aGeomT.(*geom.Point) - if !ok { - return nil, errors.Newf("arguments must be POINT geometries") - } - - bGeomT, err := b.AsGeomT() - if err != nil { - return nil, err - } - - bPoint, ok := bGeomT.(*geom.Point) - if !ok { - return nil, errors.Newf("arguments must be POINT geometries") - } - - if aPoint.Empty() || bPoint.Empty() { - return nil, errors.Newf("cannot call ST_Azimuth with POINT EMPTY") - } - - if aPoint.X() == bPoint.X() && aPoint.Y() == bPoint.Y() { - return nil, nil - } - - s, err := a.Spheroid() - if err != nil { - return nil, err - } - - _, az1, _ := s.Inverse( - s2.LatLngFromDegrees(aPoint.Y(), aPoint.X()), - s2.LatLngFromDegrees(bPoint.Y(), bPoint.X()), - ) - // Convert to radians. - az1 = az1 * math.Pi / 180 - return &az1, nil -} diff --git a/postgres/parser/geo/geogfn/best_projection.go b/postgres/parser/geo/geogfn/best_projection.go deleted file mode 100644 index ecf0afc80f..0000000000 --- a/postgres/parser/geo/geogfn/best_projection.go +++ /dev/null @@ -1,154 +0,0 @@ -// Copyright 2023 Dolthub, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Copyright 2020 The Cockroach Authors. -// -// Use of this software is governed by the Business Source License -// included in the file licenses/BSL.txt. -// -// As of the Change Date specified in that file, in accordance with -// the Business Source License, use of this software will be governed -// by the Apache License, Version 2.0, included in the file -// licenses/APL.txt. - -package geogfn - -import ( - "fmt" - "math" - - "github.com/cockroachdb/errors" - "github.com/golang/geo/s1" - "github.com/golang/geo/s2" - - "github.com/dolthub/doltgresql/postgres/parser/geo/geopb" - "github.com/dolthub/doltgresql/postgres/parser/geo/geoprojbase" -) - -// BestGeomProjection translates roughly to the ST_BestSRID function in PostGIS. -// It attempts to find the best projection for a bounding box into an accurate -// geometry-type projection. -// -// The algorithm is described by ST_Buffer/ST_Intersection documentation (paraphrased): -// It first determines the best SRID that fits the bounding box of the 2 geography objects (ST_Intersection only). -// It favors a north/south pole projection, then UTM, then LAEA for smaller zones, otherwise falling back -// to web mercator. -// If geography objects are within one half zone UTM but not the same UTM it will pick one of those. -// After the calculation is complete, it will fall back to WGS84 Geography. -func BestGeomProjection(boundingRect s2.Rect) (geoprojbase.Proj4Text, error) { - center := boundingRect.Center() - - latWidth := s1.Angle(boundingRect.Lat.Length()) - lngWidth := s1.Angle(boundingRect.Lng.Length()) - - // Check if these fit either the North Pole or South Pole areas. - // If the center has latitude greater than 70 (an arbitrary polar threshold), and it is - // within the polar ranges, return that. - if center.Lat.Degrees() > 70 && boundingRect.Lo().Lat.Degrees() > 45 { - // See: https://epsg.io/3574. - return getGeomProjection(3574) - } - // Same for south pole. - if center.Lat.Degrees() < -70 && boundingRect.Hi().Lat.Degrees() < -45 { - // See: https://epsg.io/3409 - return getGeomProjection(3409) - } - - // Each UTM zone is 6 degrees wide and distortion is low for geometries that fit within the zone. We use - // UTM if the width is lower than the UTM zone width, even though the geometry may span 2 zones -- using - // the geometry center to pick the UTM zone should result in most of the geometry being in the picked zone. - if lngWidth.Degrees() < 6 { - // Determine the offset of the projection. - // Offset longitude -180 to 0 and divide by 6 to get the zone. - // Note that we treat 180 degree longtitudes as offset 59. - // TODO(#geo): do we care about https://en.wikipedia.org/wiki/Universal_Transverse_Mercator_coordinate_system#Exceptions? - // PostGIS's _ST_BestSRID function doesn't seem to care: . - sridOffset := geopb.SRID(math.Min(math.Floor((center.Lng.Degrees()+180)/6), 59)) - if center.Lat.Degrees() >= 0 { - // Start at the north UTM SRID. - return getGeomProjection(32601 + sridOffset) - } - // Start at the south UTM SRID. - // This should make no difference in end result compared to using the north UTMs, - // but for completeness we do it. - return getGeomProjection(32701 + sridOffset) - } - - // Attempt to fit into LAEA areas if the width is less than 25 degrees (we can go up to 30 - // but want to leave some room for precision issues). - // - // LAEA areas are separated into 3 latitude zones between 0 and 90 and 3 latitude zones - // between -90 and 0. Within each latitude zones, they have different longitude bands: - // * The bands closest to the equator have 12x30 degree longitude zones. - // * The bands in the temperate area 8x45 degree longitude zones. - // * The bands near the poles have 4x90 degree longitude zones. - // - // For each of these bands, we custom define a LAEA area with the center of the LAEA area - // as the lat/lon offset. - // - // See also: https://en.wikipedia.org/wiki/Lambert_azimuthal_equal-area_projection. - if latWidth.Degrees() < 25 { - // Convert lat to a known 30 degree zone.. - // -3 represents [-90, -60), -2 represents [-60, -30) ... and 2 represents [60, 90]. - // (note: 90 is inclusive at the end). - // Treat a 90 degree latitude as band 2. - latZone := math.Min(math.Floor(center.Lat.Degrees()/30), 2) - latZoneCenterDegrees := (latZone * 30) + 15 - // Equator bands - 30 degree zones. - if (latZone == 0 || latZone == -1) && lngWidth.Degrees() <= 30 { - lngZone := math.Floor(center.Lng.Degrees() / 30) - return geoprojbase.MakeProj4Text( - fmt.Sprintf( - "+proj=laea +ellps=WGS84 +datum=WGS84 +lat_0=%g +lon_0=%g +units=m +no_defs", - latZoneCenterDegrees, - (lngZone*30)+15, - ), - ), nil - } - // Temperate bands - 45 degree zones. - if (latZone == -2 || latZone == 1) && lngWidth.Degrees() <= 45 { - lngZone := math.Floor(center.Lng.Degrees() / 45) - return geoprojbase.MakeProj4Text( - fmt.Sprintf( - "+proj=laea +ellps=WGS84 +datum=WGS84 +lat_0=%g +lon_0=%g +units=m +no_defs", - latZoneCenterDegrees, - (lngZone*45)+22.5, - ), - ), nil - } - // Polar bands -- 90 degree zones. - if (latZone == -3 || latZone == 2) && lngWidth.Degrees() <= 90 { - lngZone := math.Floor(center.Lng.Degrees() / 90) - return geoprojbase.MakeProj4Text( - fmt.Sprintf( - "+proj=laea +ellps=WGS84 +datum=WGS84 +lat_0=%g +lon_0=%g +units=m +no_defs", - latZoneCenterDegrees, - (lngZone*90)+45, - ), - ), nil - } - } - - // Default to Web Mercator. - return getGeomProjection(3857) -} - -// getGeomProjection returns the Proj4Text associated with an SRID. -func getGeomProjection(srid geopb.SRID) (geoprojbase.Proj4Text, error) { - proj, ok := geoprojbase.Projection(srid) - if !ok { - return geoprojbase.Proj4Text{}, errors.Newf("unexpected SRID %d", srid) - } - return proj.Proj4Text, nil -} diff --git a/postgres/parser/geo/geogfn/covers.go b/postgres/parser/geo/geogfn/covers.go deleted file mode 100644 index 06f4d2c04b..0000000000 --- a/postgres/parser/geo/geogfn/covers.go +++ /dev/null @@ -1,319 +0,0 @@ -// Copyright 2023 Dolthub, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Copyright 2020 The Cockroach Authors. -// -// Use of this software is governed by the Business Source License -// included in the file licenses/BSL.txt. -// -// As of the Change Date specified in that file, in accordance with -// the Business Source License, use of this software will be governed -// by the Apache License, Version 2.0, included in the file -// licenses/APL.txt. - -package geogfn - -import ( - "fmt" - - "github.com/golang/geo/s2" - - "github.com/dolthub/doltgresql/postgres/parser/geo" -) - -// Covers returns whether geography A covers geography B. -// -// This calculation is done on the sphere. -// -// Due to minor inaccuracies and lack of certain primitives in S2, -// precision for Covers will be for up to 1cm. -// -// Current limitations (which are also limitations in PostGIS): -// * POLYGON/LINESTRING only works as "contains" - if any point of the LINESTRING -// touches the boundary of the polygon, we will return false but should be true - e.g. -// SELECT st_covers( -// 'multipolygon(((0.0 0.0, 1.0 0.0, 1.0 1.0, 0.0 1.0, 0.0 0.0)), ((1.0 0.0, 2.0 0.0, 2.0 1.0, 1.0 1.0, 1.0 0.0)))', -// 'linestring(0.0 0.0, 1.0 0.0)'::geography -// ); -// -// * Furthermore, LINESTRINGS that are covered in multiple POLYGONs inside -// MULTIPOLYGON but NOT within a single POLYGON in the MULTIPOLYGON -// currently return false but should be true, e.g. -// SELECT st_covers( -// 'multipolygon(((0.0 0.0, 1.0 0.0, 1.0 1.0, 0.0 1.0, 0.0 0.0)), ((1.0 0.0, 2.0 0.0, 2.0 1.0, 1.0 1.0, 1.0 0.0)))', -// 'linestring(0.0 0.0, 2.0 0.0)'::geography -// ); -func Covers(a geo.Geography, b geo.Geography) (bool, error) { - if a.SRID() != b.SRID() { - return false, geo.NewMismatchingSRIDsError(a.SpatialObject(), b.SpatialObject()) - } - return covers(a, b) -} - -// covers is the internal calculation for Covers. -func covers(a geo.Geography, b geo.Geography) (bool, error) { - // Rect "contains" is a version of covers. - if !a.BoundingRect().Contains(b.BoundingRect()) { - return false, nil - } - - // Ignore EMPTY regions in a. - aRegions, err := a.AsS2(geo.EmptyBehaviorOmit) - if err != nil { - return false, err - } - // If any of b is empty, we cannot cover it. Error and catch to return false. - bRegions, err := b.AsS2(geo.EmptyBehaviorError) - if err != nil { - if geo.IsEmptyGeometryError(err) { - return false, nil - } - return false, err - } - - // We need to check each region in B is covered by at least - // one region of A. - bRegionsRemaining := make(map[int]struct{}, len(bRegions)) - for i := range bRegions { - bRegionsRemaining[i] = struct{}{} - } - for _, aRegion := range aRegions { - for bRegionIdx := range bRegionsRemaining { - regionCovers, err := regionCovers(aRegion, bRegions[bRegionIdx]) - if err != nil { - return false, err - } - if regionCovers { - delete(bRegionsRemaining, bRegionIdx) - } - } - if len(bRegionsRemaining) == 0 { - return true, nil - } - } - return false, nil -} - -// CoveredBy returns whether geography A is covered by geography B. -// See Covers for limitations. -func CoveredBy(a geo.Geography, b geo.Geography) (bool, error) { - if a.SRID() != b.SRID() { - return false, geo.NewMismatchingSRIDsError(a.SpatialObject(), b.SpatialObject()) - } - return covers(b, a) -} - -// regionCovers returns whether aRegion completely covers bRegion. -func regionCovers(aRegion s2.Region, bRegion s2.Region) (bool, error) { - switch aRegion := aRegion.(type) { - case s2.Point: - switch bRegion := bRegion.(type) { - case s2.Point: - return aRegion.ContainsPoint(bRegion), nil - case *s2.Polyline: - return false, nil - case *s2.Polygon: - return false, nil - default: - return false, fmt.Errorf("unknown s2 type of b: %#v", bRegion) - } - case *s2.Polyline: - switch bRegion := bRegion.(type) { - case s2.Point: - return polylineCoversPoint(aRegion, bRegion), nil - case *s2.Polyline: - return polylineCoversPolyline(aRegion, bRegion), nil - case *s2.Polygon: - return false, nil - default: - return false, fmt.Errorf("unknown s2 type of b: %#v", bRegion) - } - case *s2.Polygon: - switch bRegion := bRegion.(type) { - case s2.Point: - return polygonCoversPoint(aRegion, bRegion), nil - case *s2.Polyline: - return polygonCoversPolyline(aRegion, bRegion), nil - case *s2.Polygon: - return polygonCoversPolygon(aRegion, bRegion), nil - default: - return false, fmt.Errorf("unknown s2 type of b: %#v", bRegion) - } - } - return false, fmt.Errorf("unknown s2 type of a: %#v", aRegion) -} - -// polylineCoversPoints returns whether a polyline covers a given point. -func polylineCoversPoint(a *s2.Polyline, b s2.Point) bool { - return a.IntersectsCell(s2.CellFromPoint(b)) -} - -// polylineCoversPointsWithIdx returns whether a polyline covers a given point. -// If true, it will also return an index of the start of the edge where there -// was an intersection. -func polylineCoversPointWithIdx(a *s2.Polyline, b s2.Point) (bool, int) { - for edgeIdx := 0; edgeIdx < a.NumEdges(); edgeIdx++ { - if edgeCoversPoint(a.Edge(edgeIdx), b) { - return true, edgeIdx - } - } - return false, -1 -} - -// polygonCoversPoints returns whether a polygon covers a given point. -func polygonCoversPoint(a *s2.Polygon, b s2.Point) bool { - return a.IntersectsCell(s2.CellFromPoint(b)) -} - -// edgeCoversPoint determines whether a given edge contains a point. -func edgeCoversPoint(e s2.Edge, p s2.Point) bool { - return (&s2.Polyline{e.V0, e.V1}).IntersectsCell(s2.CellFromPoint(p)) -} - -// polylineCoversPolyline returns whether polyline a covers polyline b. -func polylineCoversPolyline(a *s2.Polyline, b *s2.Polyline) bool { - if polylineCoversPolylineOrdered(a, b) { - return true - } - // Check reverse ordering works as well. - reversedB := make([]s2.Point, len(*b)) - for i, point := range *b { - reversedB[len(reversedB)-1-i] = point - } - newBAsPolyline := s2.Polyline(reversedB) - return polylineCoversPolylineOrdered(a, &newBAsPolyline) -} - -// polylineCoversPolylineOrdered returns whether a polyline covers a polyline -// in the same ordering. -func polylineCoversPolylineOrdered(a *s2.Polyline, b *s2.Polyline) bool { - aCoversStartOfB, aCoverBStart := polylineCoversPointWithIdx(a, (*b)[0]) - // We must first check that the start of B is contained by A. - if !aCoversStartOfB { - return false - } - - aPoints := *a - bPoints := *b - // We have found "aIdx" which is the first edge in polyline A - // that includes the starting vertex of polyline "B". - // Start checking the covering from this edge. - aIdx := aCoverBStart - bIdx := 0 - - aEdge := s2.Edge{V0: aPoints[aIdx], V1: aPoints[aIdx+1]} - bEdge := s2.Edge{V0: bPoints[bIdx], V1: bPoints[bIdx+1]} - for { - aEdgeCoversBStart := edgeCoversPoint(aEdge, bEdge.V0) - aEdgeCoversBEnd := edgeCoversPoint(aEdge, bEdge.V1) - bEdgeCoversAEnd := edgeCoversPoint(bEdge, aEdge.V1) - if aEdgeCoversBStart && aEdgeCoversBEnd { - // If the edge A fully covers edge B, check the next edge. - bIdx++ - // We are out of edges in B, and A keeps going or stops at the same point. - // This is a covering. - if bIdx == len(bPoints)-1 { - return true - } - bEdge = s2.Edge{V0: bPoints[bIdx], V1: bPoints[bIdx+1]} - // If A and B end at the same place, we need to move A forward. - if bEdgeCoversAEnd { - aIdx++ - if aIdx == len(aPoints)-1 { - // At this point, B extends past A. return false. - return false - } - aEdge = s2.Edge{V0: aPoints[aIdx], V1: aPoints[aIdx+1]} - } - continue - } - - if aEdgeCoversBStart { - // Edge A doesn't cover the end of B, but it does cover the start. - // If B doesn't cover the end of A, we're done. - if !bEdgeCoversAEnd { - return false - } - // If the end of edge B covers the end of A, that means that - // B is possibly longer than A. If that's the case, truncate B - // to be the end of A, and move A forward. - bEdge.V0 = aEdge.V1 - aIdx++ - if aIdx == len(aPoints)-1 { - // At this point, B extends past A. return false. - return false - } - aEdge = s2.Edge{V0: aPoints[aIdx], V1: aPoints[aIdx+1]} - continue - } - - // Otherwise, we're doomed. - // Edge A does not contain edge B. - return false - } -} - -// polygonCoversPolyline returns whether polygon a covers polyline b. -func polygonCoversPolyline(a *s2.Polygon, b *s2.Polyline) bool { - // Check everything of polyline B is in the interior of polygon A. - for _, vertex := range *b { - if !polygonCoversPoint(a, vertex) { - return false - } - } - // Even if every point of polyline B is inside polygon A, they - // may form an edge which goes outside of polygon A and back in - // due to holes and concavities. - // - // As such, check if there are any intersections - if so, - // we do not consider it a covering. - // - // NOTE: this implementation has a limitation where a vertex of the line could - // be on the boundary and still technically be "covered" (using GEOS). - // - // However, PostGIS seems to consider this as non-covering so we can go - // with this for now. - // i.e. ` - // select st_covers( - // 'POLYGON((0.0 0.0, 1.0 0.0, 1.0 1.0, 0.0 1.0, 0.0 0.0))'::geography, - // 'LINESTRING(0.0 0.0, 1.0 1.0)'::geography); - // ` returns false, but should be true. This requires some more math to resolve. - return !polygonIntersectsPolylineEdge(a, b) -} - -// polygonIntersectsPolylineEdge returns whether polygon a intersects with -// polyline b by edge. It does not return true if the polyline is completely -// within the polygon. -func polygonIntersectsPolylineEdge(a *s2.Polygon, b *s2.Polyline) bool { - // Avoid using NumEdges / Edge of the Polygon type as it is not O(1). - for _, loop := range a.Loops() { - for loopEdgeIdx := 0; loopEdgeIdx < loop.NumEdges(); loopEdgeIdx++ { - loopEdge := loop.Edge(loopEdgeIdx) - crosser := s2.NewChainEdgeCrosser(loopEdge.V0, loopEdge.V1, (*b)[0]) - for _, nextVertex := range (*b)[1:] { - if crosser.ChainCrossingSign(nextVertex) != s2.DoNotCross { - return true - } - } - } - } - return false -} - -// polygonCoversPolygon returns whether polygon a intersects with polygon b. -func polygonCoversPolygon(a *s2.Polygon, b *s2.Polygon) bool { - // We can rely on Contains here, as if the boundaries of A and B are on top - // of each other, it is still considered a containment as well as a covering. - return a.Contains(b) -} diff --git a/postgres/parser/geo/geogfn/distance.go b/postgres/parser/geo/geogfn/distance.go deleted file mode 100644 index 7f54816a8a..0000000000 --- a/postgres/parser/geo/geogfn/distance.go +++ /dev/null @@ -1,426 +0,0 @@ -// Copyright 2023 Dolthub, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Copyright 2020 The Cockroach Authors. -// -// Use of this software is governed by the Business Source License -// included in the file licenses/BSL.txt. -// -// As of the Change Date specified in that file, in accordance with -// the Business Source License, use of this software will be governed -// by the Apache License, Version 2.0, included in the file -// licenses/APL.txt. - -package geogfn - -import ( - "math" - - "github.com/cockroachdb/errors" - "github.com/golang/geo/s1" - "github.com/golang/geo/s2" - - "github.com/dolthub/doltgresql/postgres/parser/geo" - "github.com/dolthub/doltgresql/postgres/parser/geo/geodist" - "github.com/dolthub/doltgresql/postgres/parser/geo/geographiclib" -) - -// SpheroidErrorFraction is an error fraction to compensate for using a sphere -// to calculate the distance for what is actually a spheroid. The distance -// calculation has an error that is bounded by (2 * spheroid.Flattening)%. -// This 5% margin is pretty safe. -const SpheroidErrorFraction = 0.05 - -// Distance returns the distance between geographies a and b on a sphere or spheroid. -// Returns a geo.EmptyGeometryError if any of the Geographies are EMPTY. -func Distance( - a geo.Geography, b geo.Geography, useSphereOrSpheroid UseSphereOrSpheroid, -) (float64, error) { - if a.SRID() != b.SRID() { - return 0, geo.NewMismatchingSRIDsError(a.SpatialObject(), b.SpatialObject()) - } - - aRegions, err := a.AsS2(geo.EmptyBehaviorError) - if err != nil { - return 0, err - } - bRegions, err := b.AsS2(geo.EmptyBehaviorError) - if err != nil { - return 0, err - } - spheroid, err := a.Spheroid() - if err != nil { - return 0, err - } - return distanceGeographyRegions( - spheroid, - useSphereOrSpheroid, - aRegions, - bRegions, - a.BoundingRect().Intersects(b.BoundingRect()), - 0, /* stopAfter */ - geo.FnInclusive, - ) -} - -// -// Spheroids -// - -// s2GeodistLineString implements geodist.LineString. -type s2GeodistLineString struct { - *s2.Polyline -} - -var _ geodist.LineString = (*s2GeodistLineString)(nil) - -// IsShape implements the geodist.LineString interface. -func (*s2GeodistLineString) IsShape() {} - -// LineString implements the geodist.LineString interface. -func (*s2GeodistLineString) IsLineString() {} - -// Edge implements the geodist.LineString interface. -func (g *s2GeodistLineString) Edge(i int) geodist.Edge { - return geodist.Edge{ - V0: geodist.Point{GeogPoint: (*g.Polyline)[i]}, - V1: geodist.Point{GeogPoint: (*g.Polyline)[i+1]}, - } -} - -// NumEdges implements the geodist.LineString interface. -func (g *s2GeodistLineString) NumEdges() int { - return len(*g.Polyline) - 1 -} - -// Vertex implements the geodist.LineString interface. -func (g *s2GeodistLineString) Vertex(i int) geodist.Point { - return geodist.Point{ - GeogPoint: (*g.Polyline)[i], - } -} - -// NumVertexes implements the geodist.LineString interface. -func (g *s2GeodistLineString) NumVertexes() int { - return len(*g.Polyline) -} - -// s2GeodistLinearRing implements geodist.LinearRing. -type s2GeodistLinearRing struct { - *s2.Loop -} - -var _ geodist.LinearRing = (*s2GeodistLinearRing)(nil) - -// IsShape implements the geodist.LinearRing interface. -func (*s2GeodistLinearRing) IsShape() {} - -// LinearRing implements the geodist.LinearRing interface. -func (*s2GeodistLinearRing) IsLinearRing() {} - -// Edge implements the geodist.LinearRing interface. -func (g *s2GeodistLinearRing) Edge(i int) geodist.Edge { - return geodist.Edge{ - V0: geodist.Point{GeogPoint: g.Loop.Vertex(i)}, - V1: geodist.Point{GeogPoint: g.Loop.Vertex(i + 1)}, - } -} - -// NumEdges implements the geodist.LinearRing interface. -func (g *s2GeodistLinearRing) NumEdges() int { - return g.Loop.NumEdges() -} - -// Vertex implements the geodist.LinearRing interface. -func (g *s2GeodistLinearRing) Vertex(i int) geodist.Point { - return geodist.Point{ - GeogPoint: g.Loop.Vertex(i), - } -} - -// NumVertexes implements the geodist.LinearRing interface. -func (g *s2GeodistLinearRing) NumVertexes() int { - return g.Loop.NumVertices() -} - -// s2GeodistPolygon implements geodist.Polygon. -type s2GeodistPolygon struct { - *s2.Polygon -} - -var _ geodist.Polygon = (*s2GeodistPolygon)(nil) - -// IsShape implements the geodist.Polygon interface. -func (*s2GeodistPolygon) IsShape() {} - -// Polygon implements the geodist.Polygon interface. -func (*s2GeodistPolygon) IsPolygon() {} - -// LinearRing implements the geodist.Polygon interface. -func (g *s2GeodistPolygon) LinearRing(i int) geodist.LinearRing { - return &s2GeodistLinearRing{Loop: g.Polygon.Loop(i)} -} - -// NumLinearRings implements the geodist.Polygon interface. -func (g *s2GeodistPolygon) NumLinearRings() int { - return g.Polygon.NumLoops() -} - -// s2GeodistEdgeCrosser implements geodist.EdgeCrosser. -type s2GeodistEdgeCrosser struct { - *s2.EdgeCrosser -} - -var _ geodist.EdgeCrosser = (*s2GeodistEdgeCrosser)(nil) - -// ChainCrossing implements geodist.EdgeCrosser. -func (c *s2GeodistEdgeCrosser) ChainCrossing(p geodist.Point) (bool, geodist.Point) { - // Returns nil for the intersection point as we don't require the intersection - // point as we do not have to implement ShortestLine in geography. - return c.EdgeCrosser.ChainCrossingSign(p.GeogPoint) != s2.DoNotCross, geodist.Point{} -} - -// distanceGeographyRegions calculates the distance between two sets of regions. -// If inclusive, it will quit if it finds a distance that is less than or equal -// to stopAfter. Otherwise, it will quit if a distance less than stopAfter is -// found. It is not guaranteed to find the absolute minimum distance if -// stopAfter > 0. -// -// !!! SURPRISING BEHAVIOR WARNING FOR SPHEROIDS !!! -// PostGIS evaluates the distance between spheroid regions by computing the min of -// the pair-wise distance between the cross-product of the regions in A and the regions -// in B, where the pair-wise distance is computed as: -// * Find the two closest points between the pairs of regions using the sphere -// for distance calculations. -// * Compute the spheroid distance between the two closest points. -// -// This is technically incorrect, since it is possible that the two closest points on -// the spheroid are different than the two closest points on the sphere. -// See distance_test.go for examples of the "truer" distance values. -// Since we aim to be compatible with PostGIS, we adopt the same approach. -func distanceGeographyRegions( - spheroid *geographiclib.Spheroid, - useSphereOrSpheroid UseSphereOrSpheroid, - aRegions []s2.Region, - bRegions []s2.Region, - boundingBoxIntersects bool, - stopAfter float64, - exclusivity geo.FnExclusivity, -) (float64, error) { - minDistance := math.MaxFloat64 - for _, aRegion := range aRegions { - aGeodist, err := regionToGeodistShape(aRegion) - if err != nil { - return 0, err - } - for _, bRegion := range bRegions { - minDistanceUpdater := newGeographyMinDistanceUpdater( - spheroid, - useSphereOrSpheroid, - stopAfter, - exclusivity, - ) - bGeodist, err := regionToGeodistShape(bRegion) - if err != nil { - return 0, err - } - earlyExit, err := geodist.ShapeDistance( - &geographyDistanceCalculator{ - updater: minDistanceUpdater, - boundingBoxIntersects: boundingBoxIntersects, - }, - aGeodist, - bGeodist, - ) - if err != nil { - return 0, err - } - minDistance = math.Min(minDistance, minDistanceUpdater.Distance()) - if earlyExit { - return minDistance, nil - } - } - } - return minDistance, nil -} - -// geographyMinDistanceUpdater finds the minimum distance using a sphere. -// Methods will return early if it finds a minimum distance <= stopAfterLE. -type geographyMinDistanceUpdater struct { - spheroid *geographiclib.Spheroid - useSphereOrSpheroid UseSphereOrSpheroid - minEdge s2.Edge - minD s1.ChordAngle - stopAfter s1.ChordAngle - exclusivity geo.FnExclusivity -} - -var _ geodist.DistanceUpdater = (*geographyMinDistanceUpdater)(nil) - -// newGeographyMinDistanceUpdater returns a new geographyMinDistanceUpdater with the -// correct arguments set up. -func newGeographyMinDistanceUpdater( - spheroid *geographiclib.Spheroid, - useSphereOrSpheroid UseSphereOrSpheroid, - stopAfter float64, - exclusivity geo.FnExclusivity, -) *geographyMinDistanceUpdater { - multiplier := 1.0 - if useSphereOrSpheroid == UseSpheroid { - // Modify the stopAfterLE distance to be less by the error fraction, since - // we use the sphere to calculate the distance and we want to leave a - // buffer for spheroid distances being slightly off. - multiplier -= SpheroidErrorFraction - } - stopAfterChordAngle := s1.ChordAngleFromAngle(s1.Angle(stopAfter * multiplier / spheroid.SphereRadius)) - return &geographyMinDistanceUpdater{ - spheroid: spheroid, - minD: math.MaxFloat64, - useSphereOrSpheroid: useSphereOrSpheroid, - stopAfter: stopAfterChordAngle, - exclusivity: exclusivity, - } -} - -// Distance implements the DistanceUpdater interface. -func (u *geographyMinDistanceUpdater) Distance() float64 { - // If the distance is zero, avoid the call to spheroidDistance and return early. - if u.minD == 0 { - return 0 - } - if u.useSphereOrSpheroid == UseSpheroid { - return spheroidDistance(u.spheroid, u.minEdge.V0, u.minEdge.V1) - } - return u.minD.Angle().Radians() * u.spheroid.SphereRadius -} - -// Update implements the geodist.DistanceUpdater interface. -func (u *geographyMinDistanceUpdater) Update(aPoint geodist.Point, bPoint geodist.Point) bool { - a := aPoint.GeogPoint - b := bPoint.GeogPoint - - sphereDistance := s2.ChordAngleBetweenPoints(a, b) - if sphereDistance < u.minD { - u.minD = sphereDistance - u.minEdge = s2.Edge{V0: a, V1: b} - // If we have a threshold, determine if we can stop early. - // If the sphere distance is within range of the stopAfter, we can - // definitively say we've reach the close enough point. - if (u.exclusivity == geo.FnInclusive && u.minD <= u.stopAfter) || - (u.exclusivity == geo.FnExclusive && u.minD < u.stopAfter) { - return true - } - } - return false -} - -// OnIntersects implements the geodist.DistanceUpdater interface. -func (u *geographyMinDistanceUpdater) OnIntersects(p geodist.Point) bool { - u.minD = 0 - return true -} - -// IsMaxDistance implements the geodist.DistanceUpdater interface. -func (u *geographyMinDistanceUpdater) IsMaxDistance() bool { - return false -} - -// FlipGeometries implements the geodist.DistanceUpdater interface. -func (u *geographyMinDistanceUpdater) FlipGeometries() { - // FlipGeometries is unimplemented for geographyMinDistanceUpdater as we don't - // require the order of geometries for calculation of minimum distance. -} - -// geographyDistanceCalculator implements geodist.DistanceCalculator -type geographyDistanceCalculator struct { - updater *geographyMinDistanceUpdater - boundingBoxIntersects bool -} - -var _ geodist.DistanceCalculator = (*geographyDistanceCalculator)(nil) - -// DistanceUpdater implements geodist.DistanceCalculator. -func (c *geographyDistanceCalculator) DistanceUpdater() geodist.DistanceUpdater { - return c.updater -} - -// BoundingBoxIntersects implements geodist.DistanceCalculator. -func (c *geographyDistanceCalculator) BoundingBoxIntersects() bool { - return c.boundingBoxIntersects -} - -// NewEdgeCrosser implements geodist.DistanceCalculator. -func (c *geographyDistanceCalculator) NewEdgeCrosser( - edge geodist.Edge, startPoint geodist.Point, -) geodist.EdgeCrosser { - return &s2GeodistEdgeCrosser{ - EdgeCrosser: s2.NewChainEdgeCrosser( - edge.V0.GeogPoint, - edge.V1.GeogPoint, - startPoint.GeogPoint, - ), - } -} - -// PointInLinearRing implements geodist.DistanceCalculator. -func (c *geographyDistanceCalculator) PointInLinearRing( - point geodist.Point, polygon geodist.LinearRing, -) bool { - return polygon.(*s2GeodistLinearRing).ContainsPoint(point.GeogPoint) -} - -// ClosestPointToEdge implements geodist.DistanceCalculator. -// -// ClosestPointToEdge projects the point onto the infinite line represented -// by the edge. This will return the point on the line closest to the edge. -// It will return the closest point on the line, as well as a bool representing -// whether the point that is projected lies directly on the edge as a segment. -// -// For visualization and more, see: Section 6 / Figure 4 of -// "Projective configuration theorems: old wine into new wineskins", Tabachnikov, Serge, 2016/07/16 -func (c *geographyDistanceCalculator) ClosestPointToEdge( - edge geodist.Edge, point geodist.Point, -) (geodist.Point, bool) { - eV0 := edge.V0.GeogPoint - eV1 := edge.V1.GeogPoint - - // Project the point onto the normal of the edge. A great circle passing through - // the normal and the point will intersect with the great circle represented - // by the given edge. - normal := eV0.Vector.Cross(eV1.Vector).Normalize() - // To find the point where the great circle represented by the edge and the - // great circle represented by (normal, point), we project the point - // onto the normal. - normalScaledToPoint := normal.Mul(normal.Dot(point.GeogPoint.Vector)) - // The difference between the point and the projection of the normal when normalized - // should give us a point on the great circle which contains the vertexes of the edge. - closestPoint := s2.Point{Vector: point.GeogPoint.Vector.Sub(normalScaledToPoint).Normalize()} - // We then check whether the given point lies on the geodesic of the edge, - // as the above algorithm only generates a point on the great circle - // represented by the edge. - return geodist.Point{GeogPoint: closestPoint}, (&s2.Polyline{eV0, eV1}).IntersectsCell(s2.CellFromPoint(closestPoint)) -} - -// regionToGeodistShape converts the s2 Region to a geodist object. -func regionToGeodistShape(r s2.Region) (geodist.Shape, error) { - switch r := r.(type) { - case s2.Point: - return &geodist.Point{GeogPoint: r}, nil - case *s2.Polyline: - return &s2GeodistLineString{Polyline: r}, nil - case *s2.Polygon: - return &s2GeodistPolygon{Polygon: r}, nil - } - return nil, errors.Newf("unknown region: %T", r) -} diff --git a/postgres/parser/geo/geogfn/dwithin.go b/postgres/parser/geo/geogfn/dwithin.go deleted file mode 100644 index 86589018dd..0000000000 --- a/postgres/parser/geo/geogfn/dwithin.go +++ /dev/null @@ -1,94 +0,0 @@ -// Copyright 2023 Dolthub, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Copyright 2020 The Cockroach Authors. -// -// Use of this software is governed by the Business Source License -// included in the file licenses/BSL.txt. -// -// As of the Change Date specified in that file, in accordance with -// the Business Source License, use of this software will be governed -// by the Apache License, Version 2.0, included in the file -// licenses/APL.txt. - -package geogfn - -import ( - "github.com/cockroachdb/errors" - "github.com/golang/geo/s1" - - "github.com/dolthub/doltgresql/postgres/parser/geo" -) - -// DWithin returns whether a is within distance d of b. If A or B contains empty -// Geography objects, this will return false. If inclusive, DWithin is -// equivalent to Distance(a, b) <= d. Otherwise, DWithin is instead equivalent -// to Distance(a, b) < d. -func DWithin( - a geo.Geography, - b geo.Geography, - distance float64, - useSphereOrSpheroid UseSphereOrSpheroid, - exclusivity geo.FnExclusivity, -) (bool, error) { - if a.SRID() != b.SRID() { - return false, geo.NewMismatchingSRIDsError(a.SpatialObject(), b.SpatialObject()) - } - if distance < 0 { - return false, errors.Newf("dwithin distance cannot be less than zero") - } - spheroid, err := a.Spheroid() - if err != nil { - return false, err - } - - angleToExpand := s1.Angle(distance / spheroid.SphereRadius) - if useSphereOrSpheroid == UseSpheroid { - angleToExpand *= (1 + SpheroidErrorFraction) - } - if !a.BoundingCap().Expanded(angleToExpand).Intersects(b.BoundingCap()) { - return false, nil - } - - aRegions, err := a.AsS2(geo.EmptyBehaviorError) - if err != nil { - if geo.IsEmptyGeometryError(err) { - return false, nil - } - return false, err - } - bRegions, err := b.AsS2(geo.EmptyBehaviorError) - if err != nil { - if geo.IsEmptyGeometryError(err) { - return false, nil - } - return false, err - } - maybeClosestDistance, err := distanceGeographyRegions( - spheroid, - useSphereOrSpheroid, - aRegions, - bRegions, - a.BoundingRect().Intersects(b.BoundingRect()), - distance, - exclusivity, - ) - if err != nil { - return false, err - } - if exclusivity == geo.FnExclusive { - return maybeClosestDistance < distance, nil - } - return maybeClosestDistance <= distance, nil -} diff --git a/postgres/parser/geo/geogfn/geogfn.go b/postgres/parser/geo/geogfn/geogfn.go deleted file mode 100644 index 404e235d64..0000000000 --- a/postgres/parser/geo/geogfn/geogfn.go +++ /dev/null @@ -1,36 +0,0 @@ -// Copyright 2023 Dolthub, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Copyright 2020 The Cockroach Authors. -// -// Use of this software is governed by the Business Source License -// included in the file licenses/BSL.txt. -// -// As of the Change Date specified in that file, in accordance with -// the Business Source License, use of this software will be governed -// by the Apache License, Version 2.0, included in the file -// licenses/APL.txt. - -package geogfn - -// UseSphereOrSpheroid indicates whether to use a Sphere or Spheroid -// for certain calculations. -type UseSphereOrSpheroid bool - -const ( - // UseSpheroid indicates to use the spheroid for calculations. - UseSpheroid UseSphereOrSpheroid = true - // UseSphere indicates to use the sphere for calculations. - UseSphere UseSphereOrSpheroid = false -) diff --git a/postgres/parser/geo/geogfn/geographiclib.go b/postgres/parser/geo/geogfn/geographiclib.go deleted file mode 100644 index bcd0ff5be9..0000000000 --- a/postgres/parser/geo/geogfn/geographiclib.go +++ /dev/null @@ -1,37 +0,0 @@ -// Copyright 2023 Dolthub, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Copyright 2020 The Cockroach Authors. -// -// Use of this software is governed by the Business Source License -// included in the file licenses/BSL.txt. -// -// As of the Change Date specified in that file, in accordance with -// the Business Source License, use of this software will be governed -// by the Apache License, Version 2.0, included in the file -// licenses/APL.txt. - -package geogfn - -import ( - "github.com/golang/geo/s2" - - "github.com/dolthub/doltgresql/postgres/parser/geo/geographiclib" -) - -// spheroidDistance returns the s12 (meter) component of spheroid.Inverse from s2 Points. -func spheroidDistance(s *geographiclib.Spheroid, a s2.Point, b s2.Point) float64 { - inv, _, _ := s.Inverse(s2.LatLngFromPoint(a), s2.LatLngFromPoint(b)) - return inv -} diff --git a/postgres/parser/geo/geogfn/intersects.go b/postgres/parser/geo/geogfn/intersects.go deleted file mode 100644 index 86956329f4..0000000000 --- a/postgres/parser/geo/geogfn/intersects.go +++ /dev/null @@ -1,152 +0,0 @@ -// Copyright 2023 Dolthub, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Copyright 2020 The Cockroach Authors. -// -// Use of this software is governed by the Business Source License -// included in the file licenses/BSL.txt. -// -// As of the Change Date specified in that file, in accordance with -// the Business Source License, use of this software will be governed -// by the Apache License, Version 2.0, included in the file -// licenses/APL.txt. - -package geogfn - -import ( - "fmt" - - "github.com/golang/geo/s2" - - "github.com/dolthub/doltgresql/postgres/parser/geo" -) - -// Intersects returns whether geography A intersects geography B. -// This calculation is done on the sphere. -// Precision of intersect measurements is up to 1cm. -func Intersects(a geo.Geography, b geo.Geography) (bool, error) { - if !a.BoundingRect().Intersects(b.BoundingRect()) { - return false, nil - } - if a.SRID() != b.SRID() { - return false, geo.NewMismatchingSRIDsError(a.SpatialObject(), b.SpatialObject()) - } - - aRegions, err := a.AsS2(geo.EmptyBehaviorOmit) - if err != nil { - return false, err - } - bRegions, err := b.AsS2(geo.EmptyBehaviorOmit) - if err != nil { - return false, err - } - // If any of aRegions intersects any of bRegions, return true. - for _, aRegion := range aRegions { - for _, bRegion := range bRegions { - intersects, err := singleRegionIntersects(aRegion, bRegion) - if err != nil { - return false, err - } - if intersects { - return true, nil - } - } - } - return false, nil -} - -// singleRegionIntersects returns true if aRegion intersects bRegion. -func singleRegionIntersects(aRegion s2.Region, bRegion s2.Region) (bool, error) { - switch aRegion := aRegion.(type) { - case s2.Point: - switch bRegion := bRegion.(type) { - case s2.Point: - return aRegion.IntersectsCell(s2.CellFromPoint(bRegion)), nil - case *s2.Polyline: - return bRegion.IntersectsCell(s2.CellFromPoint(aRegion)), nil - case *s2.Polygon: - return bRegion.IntersectsCell(s2.CellFromPoint(aRegion)), nil - default: - return false, fmt.Errorf("unknown s2 type of b: %#v", bRegion) - } - case *s2.Polyline: - switch bRegion := bRegion.(type) { - case s2.Point: - return aRegion.IntersectsCell(s2.CellFromPoint(bRegion)), nil - case *s2.Polyline: - return polylineIntersectsPolyline(aRegion, bRegion), nil - case *s2.Polygon: - return polygonIntersectsPolyline(bRegion, aRegion), nil - default: - return false, fmt.Errorf("unknown s2 type of b: %#v", bRegion) - } - case *s2.Polygon: - switch bRegion := bRegion.(type) { - case s2.Point: - return aRegion.IntersectsCell(s2.CellFromPoint(bRegion)), nil - case *s2.Polyline: - return polygonIntersectsPolyline(aRegion, bRegion), nil - case *s2.Polygon: - return aRegion.Intersects(bRegion), nil - default: - return false, fmt.Errorf("unknown s2 type of b: %#v", bRegion) - } - } - return false, fmt.Errorf("unknown s2 type of a: %#v", aRegion) -} - -// polylineIntersectsPolyline returns whether polyline a intersects with -// polyline b. -func polylineIntersectsPolyline(a *s2.Polyline, b *s2.Polyline) bool { - for aEdgeIdx := 0; aEdgeIdx < a.NumEdges(); aEdgeIdx++ { - edge := a.Edge(aEdgeIdx) - crosser := s2.NewChainEdgeCrosser(edge.V0, edge.V1, (*b)[0]) - for _, nextVertex := range (*b)[1:] { - crossing := crosser.ChainCrossingSign(nextVertex) - if crossing != s2.DoNotCross { - return true - } - } - } - return false -} - -// polygonIntersectsPolyline returns whether polygon a intersects with -// polyline b. -func polygonIntersectsPolyline(a *s2.Polygon, b *s2.Polyline) bool { - // Check if the polygon contains any vertex of the line b. - for _, vertex := range *b { - if a.IntersectsCell(s2.CellFromPoint(vertex)) { - return true - } - } - // Here the polygon does not contain any vertex of the polyline. - // The polyline can intersect the polygon if a line goes through the polygon - // with both vertexes that are not in the interior of the polygon. - // This technique works for holes touching, or holes touching the exterior - // as the point in which the holes touch is considered an intersection. - for _, loop := range a.Loops() { - for loopEdgeIdx := 0; loopEdgeIdx < loop.NumEdges(); loopEdgeIdx++ { - loopEdge := loop.Edge(loopEdgeIdx) - crosser := s2.NewChainEdgeCrosser(loopEdge.V0, loopEdge.V1, (*b)[0]) - for _, nextVertex := range (*b)[1:] { - crossing := crosser.ChainCrossingSign(nextVertex) - if crossing != s2.DoNotCross { - return true - } - } - } - } - return false -} diff --git a/postgres/parser/geo/geogfn/segmentize.go b/postgres/parser/geo/geogfn/segmentize.go deleted file mode 100644 index 82bf3895c1..0000000000 --- a/postgres/parser/geo/geogfn/segmentize.go +++ /dev/null @@ -1,113 +0,0 @@ -// Copyright 2023 Dolthub, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Copyright 2020 The Cockroach Authors. -// -// Use of this software is governed by the Business Source License -// included in the file licenses/BSL.txt. -// -// As of the Change Date specified in that file, in accordance with -// the Business Source License, use of this software will be governed -// by the Apache License, Version 2.0, included in the file -// licenses/APL.txt. - -package geogfn - -import ( - "math" - - "github.com/cockroachdb/errors" - "github.com/golang/geo/s2" - "github.com/twpayne/go-geom" - - "github.com/dolthub/doltgresql/postgres/parser/geo" - "github.com/dolthub/doltgresql/postgres/parser/geo/geosegmentize" -) - -// Segmentize return modified Geography having no segment longer -// that given maximum segment length. -// This works by dividing each segment by a power of 2 to find the -// smallest power less than or equal to the segmentMaxLength. -func Segmentize(geography geo.Geography, segmentMaxLength float64) (geo.Geography, error) { - geometry, err := geography.AsGeomT() - if err != nil { - return geo.Geography{}, err - } - switch geometry := geometry.(type) { - case *geom.Point, *geom.MultiPoint: - return geography, nil - default: - if segmentMaxLength <= 0 { - return geo.Geography{}, errors.Newf("maximum segment length must be positive") - } - spheroid, err := geography.Spheroid() - if err != nil { - return geo.Geography{}, err - } - // Convert segmentMaxLength to Angle with respect to earth sphere as - // further calculation is done considering segmentMaxLength as Angle. - segmentMaxAngle := segmentMaxLength / spheroid.SphereRadius - segGeometry, err := geosegmentize.SegmentizeGeom(geometry, segmentMaxAngle, segmentizeCoords) - if err != nil { - return geo.Geography{}, err - } - return geo.MakeGeographyFromGeomT(segGeometry) - } -} - -// segmentizeCoords inserts multiple points between given two coordinates and -// return resultant point as flat []float64. Such that distance between any two -// points is less than given maximum segment's length, the total number of -// segments is the power of 2, and all the segments are of the same length. -// Note: List of points does not consist of end point. -func segmentizeCoords(a geom.Coord, b geom.Coord, segmentMaxAngle float64) ([]float64, error) { - // Converted geom.Coord into s2.Point so we can segmentize the coordinates. - pointA := s2.PointFromLatLng(s2.LatLngFromDegrees(a.Y(), a.X())) - pointB := s2.PointFromLatLng(s2.LatLngFromDegrees(b.Y(), b.X())) - - chordAngleBetweenPoints := s2.ChordAngleBetweenPoints(pointA, pointB).Angle().Radians() - // PostGIS' behavior appears to involve cutting this down into segments divisible - // by a power of two. As such, we do not use ceil(chordAngleBetweenPoints/segmentMaxAngle). - // - // This calculation is to determine the total number of segment between given - // 2 coordinates, ensuring that the segments are divided into parts divisible by - // a power of 2. - // - // For that fraction by segment must be less than or equal to - // the fraction of max segment length to distance between point, since the - // total number of segment must be power of 2 therefore we can write as - // 1 / (2^n)[numberOfSegmentsToCreate] <= segmentMaxLength / distanceBetweenPoints < 1 / (2^(n-1)) - // (2^n)[numberOfSegmentsToCreate] >= distanceBetweenPoints / segmentMaxLength > 2^(n-1) - // therefore n = ceil(log2(segmentMaxLength/distanceBetweenPoints)). Hence - // numberOfSegmentsToCreate = 2^(ceil(log2(segmentMaxLength/distanceBetweenPoints))). - numberOfSegmentsToCreate := int(math.Pow(2, math.Ceil(math.Log2(chordAngleBetweenPoints/segmentMaxAngle)))) - numPoints := 2 * (1 + numberOfSegmentsToCreate) - if numPoints > geosegmentize.MaxPoints { - return nil, errors.Newf( - "attempting to segmentize into too many coordinates; need %d points between %v and %v, max %d", - numPoints, - a, - b, - geosegmentize.MaxPoints, - ) - } - allSegmentizedCoordinates := make([]float64, 0, numPoints) - allSegmentizedCoordinates = append(allSegmentizedCoordinates, a.X(), a.Y()) - for pointInserted := 1; pointInserted < numberOfSegmentsToCreate; pointInserted++ { - newPoint := s2.Interpolate(float64(pointInserted)/float64(numberOfSegmentsToCreate), pointA, pointB) - latLng := s2.LatLngFromPoint(newPoint) - allSegmentizedCoordinates = append(allSegmentizedCoordinates, latLng.Lng.Degrees(), latLng.Lat.Degrees()) - } - return allSegmentizedCoordinates, nil -} diff --git a/postgres/parser/geo/geogfn/topology_operations.go b/postgres/parser/geo/geogfn/topology_operations.go deleted file mode 100644 index eb166a5599..0000000000 --- a/postgres/parser/geo/geogfn/topology_operations.go +++ /dev/null @@ -1,137 +0,0 @@ -// Copyright 2023 Dolthub, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Copyright 2020 The Cockroach Authors. -// -// Use of this software is governed by the Business Source License -// included in the file licenses/BSL.txt. -// -// As of the Change Date specified in that file, in accordance with -// the Business Source License, use of this software will be governed -// by the Apache License, Version 2.0, included in the file -// licenses/APL.txt. - -package geogfn - -import ( - "github.com/cockroachdb/errors" - "github.com/golang/geo/r3" - "github.com/golang/geo/s2" - "github.com/twpayne/go-geom" - - "github.com/dolthub/doltgresql/postgres/parser/geo" -) - -// Centroid returns the Centroid of a given Geography. -// -// NOTE: In the case of (Multi)Polygon Centroid result, it doesn't mirror with -// PostGIS's result. We are using the same algorithm of dividing into triangles. -// However, The PostGIS implementation differs as it cuts triangles in a different -// way - namely, it fixes the first point in the exterior ring as the first point -// of the triangle, whereas we always update the reference point to be -// the first point of the ring when moving from one ring to another. -// -// See: http://jennessent.com/downloads/Graphics_Shapes_Manual_A4.pdf#page=49 -// for more details. -// -// Ideally, both implementations should provide the same result. However, the -// centroid of the triangles is the vectorized mean of all the points, not the -// actual projection in the Spherical surface, which causes a small inaccuracies. -// This inaccuracy will eventually grow if there is a substantial -// number of a triangle with a larger area. -func Centroid(g geo.Geography, useSphereOrSpheroid UseSphereOrSpheroid) (geo.Geography, error) { - geomRepr, err := g.AsGeomT() - if err != nil { - return geo.Geography{}, err - } - if geomRepr.Empty() { - return geo.MakeGeographyFromGeomT(geom.NewGeometryCollection().SetSRID(geomRepr.SRID())) - } - switch geomRepr.(type) { - case *geom.Point, *geom.LineString, *geom.Polygon, *geom.MultiPoint, *geom.MultiLineString, *geom.MultiPolygon: - default: - return geo.Geography{}, errors.Newf("unhandled geography type %s", g.ShapeType().String()) - } - - regions, err := geo.S2RegionsFromGeomT(geomRepr, geo.EmptyBehaviorOmit) - if err != nil { - return geo.Geography{}, err - } - spheroid, err := g.Spheroid() - if err != nil { - return geo.Geography{}, err - } - - // localWeightedCentroids is the collection of all the centroid corresponds to - // various small regions in which we divide the given region for calculation - // of centroid. The magnitude of each s2.Point.Vector represents - // the weight corresponding to its region. - var localWeightedCentroids []s2.Point - for _, region := range regions { - switch region := region.(type) { - case s2.Point: - localWeightedCentroids = append(localWeightedCentroids, region) - case *s2.Polyline: - // The algorithm used for the calculation of centroid for (Multi)LineString: - // * Split (Multi)LineString in the set of individual edges. - // * Calculate the mid-points and length/angle for all the edges. - // * The centroid of (Multi)LineString will be a weighted average of mid-points - // of all the edges, where each mid-points is weighted by its length/angle. - for edgeIdx := 0; edgeIdx < region.NumEdges(); edgeIdx++ { - var edgeWeight float64 - eV0 := region.Edge(edgeIdx).V0 - eV1 := region.Edge(edgeIdx).V1 - if useSphereOrSpheroid == UseSpheroid { - edgeWeight = spheroidDistance(spheroid, eV0, eV1) - } else { - edgeWeight = float64(s2.ChordAngleBetweenPoints(eV0, eV1).Angle()) - } - localWeightedCentroids = append(localWeightedCentroids, s2.Point{Vector: eV0.Add(eV1.Vector).Mul(edgeWeight)}) - } - case *s2.Polygon: - // The algorithm used for the calculation of centroid for (Multi)Polygon: - // * Split (Multi)Polygon in the set of individual triangles. - // * Calculate the centroid and signed area (negative area for triangle inside - // the hole) for all the triangle. - // * The centroid of (Multi)Polygon will be a weighted average of the centroid - // of all the triangle, where each centroid is weighted by its area. - for _, loop := range region.Loops() { - triangleVertices := make([]s2.Point, 4) - triangleVertices[0] = loop.Vertex(0) - triangleVertices[3] = loop.Vertex(0) - - for pointIdx := 1; pointIdx+2 < loop.NumVertices(); pointIdx++ { - triangleVertices[1] = loop.Vertex(pointIdx) - triangleVertices[2] = loop.Vertex(pointIdx + 1) - triangleCentroid := s2.PlanarCentroid(triangleVertices[0], triangleVertices[1], triangleVertices[2]) - var area float64 - if useSphereOrSpheroid == UseSpheroid { - area, _ = spheroid.AreaAndPerimeter(triangleVertices[:3]) - } else { - area = s2.LoopFromPoints(triangleVertices).Area() - } - area = area * float64(loop.Sign()) - localWeightedCentroids = append(localWeightedCentroids, s2.Point{Vector: triangleCentroid.Mul(area)}) - } - } - } - } - var centroidVector r3.Vector - for _, point := range localWeightedCentroids { - centroidVector = centroidVector.Add(point.Vector) - } - latLng := s2.LatLngFromPoint(s2.Point{Vector: centroidVector.Normalize()}) - centroid := geom.NewPointFlat(geom.XY, []float64{latLng.Lng.Degrees(), latLng.Lat.Degrees()}).SetSRID(int(g.SRID())) - return geo.MakeGeographyFromGeomT(centroid) -} diff --git a/postgres/parser/geo/geogfn/unary_operators.go b/postgres/parser/geo/geogfn/unary_operators.go deleted file mode 100644 index e5334272f4..0000000000 --- a/postgres/parser/geo/geogfn/unary_operators.go +++ /dev/null @@ -1,214 +0,0 @@ -// Copyright 2023 Dolthub, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Copyright 2020 The Cockroach Authors. -// -// Use of this software is governed by the Business Source License -// included in the file licenses/BSL.txt. -// -// As of the Change Date specified in that file, in accordance with -// the Business Source License, use of this software will be governed -// by the Apache License, Version 2.0, included in the file -// licenses/APL.txt. - -package geogfn - -import ( - "math" - - "github.com/cockroachdb/errors" - "github.com/golang/geo/s1" - "github.com/golang/geo/s2" - "github.com/twpayne/go-geom" - - "github.com/dolthub/doltgresql/postgres/parser/geo" - "github.com/dolthub/doltgresql/postgres/parser/geo/geographiclib" -) - -// Area returns the area of a given Geography. -func Area(g geo.Geography, useSphereOrSpheroid UseSphereOrSpheroid) (float64, error) { - regions, err := g.AsS2(geo.EmptyBehaviorOmit) - if err != nil { - return 0, err - } - spheroid, err := g.Spheroid() - if err != nil { - return 0, err - } - - var totalArea float64 - for _, region := range regions { - switch region := region.(type) { - case s2.Point, *s2.Polyline: - case *s2.Polygon: - if useSphereOrSpheroid == UseSpheroid { - for _, loop := range region.Loops() { - points := loop.Vertices() - area, _ := spheroid.AreaAndPerimeter(points[:len(points)-1]) - totalArea += float64(loop.Sign()) * area - } - } else { - totalArea += region.Area() - } - default: - return 0, errors.Newf("unknown type: %T", region) - } - } - if useSphereOrSpheroid == UseSphere { - totalArea *= spheroid.SphereRadius * spheroid.SphereRadius - } - return totalArea, nil -} - -// Perimeter returns the perimeter of a given Geography. -func Perimeter(g geo.Geography, useSphereOrSpheroid UseSphereOrSpheroid) (float64, error) { - gt, err := g.AsGeomT() - if err != nil { - return 0, err - } - // This check mirrors PostGIS behavior, where GeometryCollections - // of LineStrings include the length for perimeters. - switch gt.(type) { - case *geom.Polygon, *geom.MultiPolygon, *geom.GeometryCollection: - default: - return 0, nil - } - regions, err := geo.S2RegionsFromGeomT(gt, geo.EmptyBehaviorOmit) - if err != nil { - return 0, err - } - spheroid, err := g.Spheroid() - if err != nil { - return 0, err - } - return length(regions, spheroid, useSphereOrSpheroid) -} - -// Length returns length of a given Geography. -func Length(g geo.Geography, useSphereOrSpheroid UseSphereOrSpheroid) (float64, error) { - gt, err := g.AsGeomT() - if err != nil { - return 0, err - } - // This check mirrors PostGIS behavior, where GeometryCollections - // of Polygons include the perimeters for polygons. - switch gt.(type) { - case *geom.LineString, *geom.MultiLineString, *geom.GeometryCollection: - default: - return 0, nil - } - regions, err := geo.S2RegionsFromGeomT(gt, geo.EmptyBehaviorOmit) - if err != nil { - return 0, err - } - spheroid, err := g.Spheroid() - if err != nil { - return 0, err - } - return length(regions, spheroid, useSphereOrSpheroid) -} - -// Project returns calculate a projected point given a source point, a distance and a azimuth. -func Project(g geo.Geography, distance float64, azimuth s1.Angle) (geo.Geography, error) { - geomT, err := g.AsGeomT() - if err != nil { - return geo.Geography{}, err - } - - point, ok := geomT.(*geom.Point) - if !ok { - return geo.Geography{}, errors.Newf("ST_Project(geography) is only valid for point inputs") - } - - spheroid, err := g.Spheroid() - if err != nil { - return geo.Geography{}, err - } - - // Normalize distance to be positive. - if distance < 0.0 { - distance = -distance - azimuth += math.Pi - } - - // Normalize azimuth - azimuth = azimuth.Normalized() - - // Check the distance validity. - if distance > (math.Pi * spheroid.Radius) { - return geo.Geography{}, errors.Newf("distance must not be greater than %f", math.Pi*spheroid.Radius) - } - - if point.Empty() { - return geo.Geography{}, errors.Newf("cannot project POINT EMPTY") - } - - // Convert to ta geodetic point. - x := point.X() - y := point.Y() - - projected := spheroid.Project( - s2.LatLngFromDegrees(x, y), - distance, - azimuth, - ) - - ret := geom.NewPointFlat( - geom.XY, - []float64{ - geo.NormalizeLongitudeDegrees(projected.Lng.Degrees()), - geo.NormalizeLatitudeDegrees(projected.Lat.Degrees()), - }, - ).SetSRID(point.SRID()) - return geo.MakeGeographyFromGeomT(ret) -} - -// length returns the sum of the lengtsh and perimeters in the shapes of the Geography. -// In OGC parlance, length returns both LineString lengths _and_ Polygon perimeters. -func length( - regions []s2.Region, spheroid *geographiclib.Spheroid, useSphereOrSpheroid UseSphereOrSpheroid, -) (float64, error) { - var totalLength float64 - for _, region := range regions { - switch region := region.(type) { - case s2.Point: - case *s2.Polyline: - if useSphereOrSpheroid == UseSpheroid { - totalLength += spheroid.InverseBatch((*region)) - } else { - for edgeIdx := 0; edgeIdx < region.NumEdges(); edgeIdx++ { - edge := region.Edge(edgeIdx) - totalLength += s2.ChordAngleBetweenPoints(edge.V0, edge.V1).Angle().Radians() - } - } - case *s2.Polygon: - for _, loop := range region.Loops() { - if useSphereOrSpheroid == UseSpheroid { - totalLength += spheroid.InverseBatch(loop.Vertices()) - } else { - for edgeIdx := 0; edgeIdx < loop.NumEdges(); edgeIdx++ { - edge := loop.Edge(edgeIdx) - totalLength += s2.ChordAngleBetweenPoints(edge.V0, edge.V1).Angle().Radians() - } - } - } - default: - return 0, errors.Newf("unknown type: %T", region) - } - } - if useSphereOrSpheroid == UseSphere { - totalLength *= spheroid.SphereRadius - } - return totalLength, nil -} diff --git a/postgres/parser/geo/geoindex/config.pb.go b/postgres/parser/geo/geoindex/config.pb.go deleted file mode 100644 index b7235470a1..0000000000 --- a/postgres/parser/geo/geoindex/config.pb.go +++ /dev/null @@ -1,1172 +0,0 @@ -// Code generated by protoc-gen-gogo. DO NOT EDIT. -// source: geo/geoindex/config.proto - -package geoindex - -import ( - encoding_binary "encoding/binary" - fmt "fmt" - io "io" - math "math" - - proto "github.com/gogo/protobuf/proto" -) - -// Reference imports to suppress errors if they are not otherwise used. -var _ = proto.Marshal -var _ = fmt.Errorf -var _ = math.Inf - -// This is a compile-time assertion to ensure that this generated file -// is compatible with the proto package it is being compiled against. -// A compilation error at this line likely means your copy of the -// proto package needs to be updated. -const _ = proto.GoGoProtoPackageIsVersion2 // please upgrade the proto package - -// Config is the information used to tune one instance of a geospatial index. -// Each SQL index will have its own config. -// -// At the moment, only one major indexing strategy is implemented (S2 cells). -type Config struct { - S2Geography *S2GeographyConfig `protobuf:"bytes,1,opt,name=s2_geography,json=s2Geography,proto3" json:"s2_geography,omitempty"` - S2Geometry *S2GeometryConfig `protobuf:"bytes,2,opt,name=s2_geometry,json=s2Geometry,proto3" json:"s2_geometry,omitempty"` -} - -func (m *Config) Reset() { *m = Config{} } -func (m *Config) String() string { return proto.CompactTextString(m) } -func (*Config) ProtoMessage() {} -func (*Config) Descriptor() ([]byte, []int) { - return fileDescriptor_config_4fdfa32e25381f1e, []int{0} -} -func (m *Config) XXX_Unmarshal(b []byte) error { - return m.Unmarshal(b) -} -func (m *Config) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { - b = b[:cap(b)] - n, err := m.MarshalTo(b) - if err != nil { - return nil, err - } - return b[:n], nil -} -func (dst *Config) XXX_Merge(src proto.Message) { - xxx_messageInfo_Config.Merge(dst, src) -} -func (m *Config) XXX_Size() int { - return m.Size() -} -func (m *Config) XXX_DiscardUnknown() { - xxx_messageInfo_Config.DiscardUnknown(m) -} - -var xxx_messageInfo_Config proto.InternalMessageInfo - -// S2Config is the required information to tune one instance of an S2 cell -// backed geospatial index. For advanced users only -- the defaults should be -// good enough. -// -// TODO(sumeer): Based on experiments, reduce the knobs below by making the -// covering self-tuning. -type S2Config struct { - // MinLevel is the minimum cell level stored in the index. If left unset, it - // defaults to 0. - MinLevel int32 `protobuf:"varint,1,opt,name=min_level,json=minLevel,proto3" json:"min_level,omitempty"` - // MaxLevel is the maximum cell level stored in the index. If left unset, it - // defaults to 30. - MaxLevel int32 `protobuf:"varint,2,opt,name=max_level,json=maxLevel,proto3" json:"max_level,omitempty"` - // `MaxLevel-MinLevel` must be an exact multiple of LevelMod. If left unset, - // it defaults to 1. - LevelMod int32 `protobuf:"varint,3,opt,name=level_mod,json=levelMod,proto3" json:"level_mod,omitempty"` - // MaxCells is a soft hint for the maximum number of entries used to store a - // single geospatial object. If left unset, it defaults to 4. - MaxCells int32 `protobuf:"varint,4,opt,name=max_cells,json=maxCells,proto3" json:"max_cells,omitempty"` -} - -func (m *S2Config) Reset() { *m = S2Config{} } -func (m *S2Config) String() string { return proto.CompactTextString(m) } -func (*S2Config) ProtoMessage() {} -func (*S2Config) Descriptor() ([]byte, []int) { - return fileDescriptor_config_4fdfa32e25381f1e, []int{1} -} -func (m *S2Config) XXX_Unmarshal(b []byte) error { - return m.Unmarshal(b) -} -func (m *S2Config) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { - b = b[:cap(b)] - n, err := m.MarshalTo(b) - if err != nil { - return nil, err - } - return b[:n], nil -} -func (dst *S2Config) XXX_Merge(src proto.Message) { - xxx_messageInfo_S2Config.Merge(dst, src) -} -func (m *S2Config) XXX_Size() int { - return m.Size() -} -func (m *S2Config) XXX_DiscardUnknown() { - xxx_messageInfo_S2Config.DiscardUnknown(m) -} - -var xxx_messageInfo_S2Config proto.InternalMessageInfo - -type S2GeographyConfig struct { - S2Config *S2Config `protobuf:"bytes,1,opt,name=s2_config,json=s2Config,proto3" json:"s2_config,omitempty"` -} - -func (m *S2GeographyConfig) Reset() { *m = S2GeographyConfig{} } -func (m *S2GeographyConfig) String() string { return proto.CompactTextString(m) } -func (*S2GeographyConfig) ProtoMessage() {} -func (*S2GeographyConfig) Descriptor() ([]byte, []int) { - return fileDescriptor_config_4fdfa32e25381f1e, []int{2} -} -func (m *S2GeographyConfig) XXX_Unmarshal(b []byte) error { - return m.Unmarshal(b) -} -func (m *S2GeographyConfig) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { - b = b[:cap(b)] - n, err := m.MarshalTo(b) - if err != nil { - return nil, err - } - return b[:n], nil -} -func (dst *S2GeographyConfig) XXX_Merge(src proto.Message) { - xxx_messageInfo_S2GeographyConfig.Merge(dst, src) -} -func (m *S2GeographyConfig) XXX_Size() int { - return m.Size() -} -func (m *S2GeographyConfig) XXX_DiscardUnknown() { - xxx_messageInfo_S2GeographyConfig.DiscardUnknown(m) -} - -var xxx_messageInfo_S2GeographyConfig proto.InternalMessageInfo - -type S2GeometryConfig struct { - // The rectangle bounds of the plane that will be efficiently indexed. Shapes - // should rarely exceed these bounds. - MinX float64 `protobuf:"fixed64,1,opt,name=min_x,json=minX,proto3" json:"min_x,omitempty"` - MaxX float64 `protobuf:"fixed64,2,opt,name=max_x,json=maxX,proto3" json:"max_x,omitempty"` - MinY float64 `protobuf:"fixed64,3,opt,name=min_y,json=minY,proto3" json:"min_y,omitempty"` - MaxY float64 `protobuf:"fixed64,4,opt,name=max_y,json=maxY,proto3" json:"max_y,omitempty"` - S2Config *S2Config `protobuf:"bytes,5,opt,name=s2_config,json=s2Config,proto3" json:"s2_config,omitempty"` -} - -func (m *S2GeometryConfig) Reset() { *m = S2GeometryConfig{} } -func (m *S2GeometryConfig) String() string { return proto.CompactTextString(m) } -func (*S2GeometryConfig) ProtoMessage() {} -func (*S2GeometryConfig) Descriptor() ([]byte, []int) { - return fileDescriptor_config_4fdfa32e25381f1e, []int{3} -} -func (m *S2GeometryConfig) XXX_Unmarshal(b []byte) error { - return m.Unmarshal(b) -} -func (m *S2GeometryConfig) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { - b = b[:cap(b)] - n, err := m.MarshalTo(b) - if err != nil { - return nil, err - } - return b[:n], nil -} -func (dst *S2GeometryConfig) XXX_Merge(src proto.Message) { - xxx_messageInfo_S2GeometryConfig.Merge(dst, src) -} -func (m *S2GeometryConfig) XXX_Size() int { - return m.Size() -} -func (m *S2GeometryConfig) XXX_DiscardUnknown() { - xxx_messageInfo_S2GeometryConfig.DiscardUnknown(m) -} - -var xxx_messageInfo_S2GeometryConfig proto.InternalMessageInfo - -func init() { - proto.RegisterType((*Config)(nil), "cockroach.geo.geoindex.Config") - proto.RegisterType((*S2Config)(nil), "cockroach.geo.geoindex.S2Config") - proto.RegisterType((*S2GeographyConfig)(nil), "cockroach.geo.geoindex.S2GeographyConfig") - proto.RegisterType((*S2GeometryConfig)(nil), "cockroach.geo.geoindex.S2GeometryConfig") -} -func (this *Config) Equal(that interface{}) bool { - if that == nil { - return this == nil - } - - that1, ok := that.(*Config) - if !ok { - that2, ok := that.(Config) - if ok { - that1 = &that2 - } else { - return false - } - } - if that1 == nil { - return this == nil - } else if this == nil { - return false - } - if !this.S2Geography.Equal(that1.S2Geography) { - return false - } - if !this.S2Geometry.Equal(that1.S2Geometry) { - return false - } - return true -} -func (this *S2Config) Equal(that interface{}) bool { - if that == nil { - return this == nil - } - - that1, ok := that.(*S2Config) - if !ok { - that2, ok := that.(S2Config) - if ok { - that1 = &that2 - } else { - return false - } - } - if that1 == nil { - return this == nil - } else if this == nil { - return false - } - if this.MinLevel != that1.MinLevel { - return false - } - if this.MaxLevel != that1.MaxLevel { - return false - } - if this.LevelMod != that1.LevelMod { - return false - } - if this.MaxCells != that1.MaxCells { - return false - } - return true -} -func (this *S2GeographyConfig) Equal(that interface{}) bool { - if that == nil { - return this == nil - } - - that1, ok := that.(*S2GeographyConfig) - if !ok { - that2, ok := that.(S2GeographyConfig) - if ok { - that1 = &that2 - } else { - return false - } - } - if that1 == nil { - return this == nil - } else if this == nil { - return false - } - if !this.S2Config.Equal(that1.S2Config) { - return false - } - return true -} -func (this *S2GeometryConfig) Equal(that interface{}) bool { - if that == nil { - return this == nil - } - - that1, ok := that.(*S2GeometryConfig) - if !ok { - that2, ok := that.(S2GeometryConfig) - if ok { - that1 = &that2 - } else { - return false - } - } - if that1 == nil { - return this == nil - } else if this == nil { - return false - } - if this.MinX != that1.MinX { - return false - } - if this.MaxX != that1.MaxX { - return false - } - if this.MinY != that1.MinY { - return false - } - if this.MaxY != that1.MaxY { - return false - } - if !this.S2Config.Equal(that1.S2Config) { - return false - } - return true -} -func (m *Config) Marshal() (dAtA []byte, err error) { - size := m.Size() - dAtA = make([]byte, size) - n, err := m.MarshalTo(dAtA) - if err != nil { - return nil, err - } - return dAtA[:n], nil -} - -func (m *Config) MarshalTo(dAtA []byte) (int, error) { - var i int - _ = i - var l int - _ = l - if m.S2Geography != nil { - dAtA[i] = 0xa - i++ - i = encodeVarintConfig(dAtA, i, uint64(m.S2Geography.Size())) - n1, err := m.S2Geography.MarshalTo(dAtA[i:]) - if err != nil { - return 0, err - } - i += n1 - } - if m.S2Geometry != nil { - dAtA[i] = 0x12 - i++ - i = encodeVarintConfig(dAtA, i, uint64(m.S2Geometry.Size())) - n2, err := m.S2Geometry.MarshalTo(dAtA[i:]) - if err != nil { - return 0, err - } - i += n2 - } - return i, nil -} - -func (m *S2Config) Marshal() (dAtA []byte, err error) { - size := m.Size() - dAtA = make([]byte, size) - n, err := m.MarshalTo(dAtA) - if err != nil { - return nil, err - } - return dAtA[:n], nil -} - -func (m *S2Config) MarshalTo(dAtA []byte) (int, error) { - var i int - _ = i - var l int - _ = l - if m.MinLevel != 0 { - dAtA[i] = 0x8 - i++ - i = encodeVarintConfig(dAtA, i, uint64(m.MinLevel)) - } - if m.MaxLevel != 0 { - dAtA[i] = 0x10 - i++ - i = encodeVarintConfig(dAtA, i, uint64(m.MaxLevel)) - } - if m.LevelMod != 0 { - dAtA[i] = 0x18 - i++ - i = encodeVarintConfig(dAtA, i, uint64(m.LevelMod)) - } - if m.MaxCells != 0 { - dAtA[i] = 0x20 - i++ - i = encodeVarintConfig(dAtA, i, uint64(m.MaxCells)) - } - return i, nil -} - -func (m *S2GeographyConfig) Marshal() (dAtA []byte, err error) { - size := m.Size() - dAtA = make([]byte, size) - n, err := m.MarshalTo(dAtA) - if err != nil { - return nil, err - } - return dAtA[:n], nil -} - -func (m *S2GeographyConfig) MarshalTo(dAtA []byte) (int, error) { - var i int - _ = i - var l int - _ = l - if m.S2Config != nil { - dAtA[i] = 0xa - i++ - i = encodeVarintConfig(dAtA, i, uint64(m.S2Config.Size())) - n3, err := m.S2Config.MarshalTo(dAtA[i:]) - if err != nil { - return 0, err - } - i += n3 - } - return i, nil -} - -func (m *S2GeometryConfig) Marshal() (dAtA []byte, err error) { - size := m.Size() - dAtA = make([]byte, size) - n, err := m.MarshalTo(dAtA) - if err != nil { - return nil, err - } - return dAtA[:n], nil -} - -func (m *S2GeometryConfig) MarshalTo(dAtA []byte) (int, error) { - var i int - _ = i - var l int - _ = l - if m.MinX != 0 { - dAtA[i] = 0x9 - i++ - encoding_binary.LittleEndian.PutUint64(dAtA[i:], uint64(math.Float64bits(float64(m.MinX)))) - i += 8 - } - if m.MaxX != 0 { - dAtA[i] = 0x11 - i++ - encoding_binary.LittleEndian.PutUint64(dAtA[i:], uint64(math.Float64bits(float64(m.MaxX)))) - i += 8 - } - if m.MinY != 0 { - dAtA[i] = 0x19 - i++ - encoding_binary.LittleEndian.PutUint64(dAtA[i:], uint64(math.Float64bits(float64(m.MinY)))) - i += 8 - } - if m.MaxY != 0 { - dAtA[i] = 0x21 - i++ - encoding_binary.LittleEndian.PutUint64(dAtA[i:], uint64(math.Float64bits(float64(m.MaxY)))) - i += 8 - } - if m.S2Config != nil { - dAtA[i] = 0x2a - i++ - i = encodeVarintConfig(dAtA, i, uint64(m.S2Config.Size())) - n4, err := m.S2Config.MarshalTo(dAtA[i:]) - if err != nil { - return 0, err - } - i += n4 - } - return i, nil -} - -func encodeVarintConfig(dAtA []byte, offset int, v uint64) int { - for v >= 1<<7 { - dAtA[offset] = uint8(v&0x7f | 0x80) - v >>= 7 - offset++ - } - dAtA[offset] = uint8(v) - return offset + 1 -} -func (m *Config) Size() (n int) { - if m == nil { - return 0 - } - var l int - _ = l - if m.S2Geography != nil { - l = m.S2Geography.Size() - n += 1 + l + sovConfig(uint64(l)) - } - if m.S2Geometry != nil { - l = m.S2Geometry.Size() - n += 1 + l + sovConfig(uint64(l)) - } - return n -} - -func (m *S2Config) Size() (n int) { - if m == nil { - return 0 - } - var l int - _ = l - if m.MinLevel != 0 { - n += 1 + sovConfig(uint64(m.MinLevel)) - } - if m.MaxLevel != 0 { - n += 1 + sovConfig(uint64(m.MaxLevel)) - } - if m.LevelMod != 0 { - n += 1 + sovConfig(uint64(m.LevelMod)) - } - if m.MaxCells != 0 { - n += 1 + sovConfig(uint64(m.MaxCells)) - } - return n -} - -func (m *S2GeographyConfig) Size() (n int) { - if m == nil { - return 0 - } - var l int - _ = l - if m.S2Config != nil { - l = m.S2Config.Size() - n += 1 + l + sovConfig(uint64(l)) - } - return n -} - -func (m *S2GeometryConfig) Size() (n int) { - if m == nil { - return 0 - } - var l int - _ = l - if m.MinX != 0 { - n += 9 - } - if m.MaxX != 0 { - n += 9 - } - if m.MinY != 0 { - n += 9 - } - if m.MaxY != 0 { - n += 9 - } - if m.S2Config != nil { - l = m.S2Config.Size() - n += 1 + l + sovConfig(uint64(l)) - } - return n -} - -func sovConfig(x uint64) (n int) { - for { - n++ - x >>= 7 - if x == 0 { - break - } - } - return n -} -func sozConfig(x uint64) (n int) { - return sovConfig(uint64((x << 1) ^ uint64((int64(x) >> 63)))) -} -func (this *Config) GetValue() interface{} { - if this.S2Geography != nil { - return this.S2Geography - } - if this.S2Geometry != nil { - return this.S2Geometry - } - return nil -} - -func (this *Config) SetValue(value interface{}) bool { - switch vt := value.(type) { - case *S2GeographyConfig: - this.S2Geography = vt - case *S2GeometryConfig: - this.S2Geometry = vt - default: - return false - } - return true -} -func (m *Config) Unmarshal(dAtA []byte) error { - l := len(dAtA) - iNdEx := 0 - for iNdEx < l { - preIndex := iNdEx - var wire uint64 - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return ErrIntOverflowConfig - } - if iNdEx >= l { - return io.ErrUnexpectedEOF - } - b := dAtA[iNdEx] - iNdEx++ - wire |= (uint64(b) & 0x7F) << shift - if b < 0x80 { - break - } - } - fieldNum := int32(wire >> 3) - wireType := int(wire & 0x7) - if wireType == 4 { - return fmt.Errorf("proto: Config: wiretype end group for non-group") - } - if fieldNum <= 0 { - return fmt.Errorf("proto: Config: illegal tag %d (wire type %d)", fieldNum, wire) - } - switch fieldNum { - case 1: - if wireType != 2 { - return fmt.Errorf("proto: wrong wireType = %d for field S2Geography", wireType) - } - var msglen int - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return ErrIntOverflowConfig - } - if iNdEx >= l { - return io.ErrUnexpectedEOF - } - b := dAtA[iNdEx] - iNdEx++ - msglen |= (int(b) & 0x7F) << shift - if b < 0x80 { - break - } - } - if msglen < 0 { - return ErrInvalidLengthConfig - } - postIndex := iNdEx + msglen - if postIndex > l { - return io.ErrUnexpectedEOF - } - if m.S2Geography == nil { - m.S2Geography = &S2GeographyConfig{} - } - if err := m.S2Geography.Unmarshal(dAtA[iNdEx:postIndex]); err != nil { - return err - } - iNdEx = postIndex - case 2: - if wireType != 2 { - return fmt.Errorf("proto: wrong wireType = %d for field S2Geometry", wireType) - } - var msglen int - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return ErrIntOverflowConfig - } - if iNdEx >= l { - return io.ErrUnexpectedEOF - } - b := dAtA[iNdEx] - iNdEx++ - msglen |= (int(b) & 0x7F) << shift - if b < 0x80 { - break - } - } - if msglen < 0 { - return ErrInvalidLengthConfig - } - postIndex := iNdEx + msglen - if postIndex > l { - return io.ErrUnexpectedEOF - } - if m.S2Geometry == nil { - m.S2Geometry = &S2GeometryConfig{} - } - if err := m.S2Geometry.Unmarshal(dAtA[iNdEx:postIndex]); err != nil { - return err - } - iNdEx = postIndex - default: - iNdEx = preIndex - skippy, err := skipConfig(dAtA[iNdEx:]) - if err != nil { - return err - } - if skippy < 0 { - return ErrInvalidLengthConfig - } - if (iNdEx + skippy) > l { - return io.ErrUnexpectedEOF - } - iNdEx += skippy - } - } - - if iNdEx > l { - return io.ErrUnexpectedEOF - } - return nil -} -func (m *S2Config) Unmarshal(dAtA []byte) error { - l := len(dAtA) - iNdEx := 0 - for iNdEx < l { - preIndex := iNdEx - var wire uint64 - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return ErrIntOverflowConfig - } - if iNdEx >= l { - return io.ErrUnexpectedEOF - } - b := dAtA[iNdEx] - iNdEx++ - wire |= (uint64(b) & 0x7F) << shift - if b < 0x80 { - break - } - } - fieldNum := int32(wire >> 3) - wireType := int(wire & 0x7) - if wireType == 4 { - return fmt.Errorf("proto: S2Config: wiretype end group for non-group") - } - if fieldNum <= 0 { - return fmt.Errorf("proto: S2Config: illegal tag %d (wire type %d)", fieldNum, wire) - } - switch fieldNum { - case 1: - if wireType != 0 { - return fmt.Errorf("proto: wrong wireType = %d for field MinLevel", wireType) - } - m.MinLevel = 0 - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return ErrIntOverflowConfig - } - if iNdEx >= l { - return io.ErrUnexpectedEOF - } - b := dAtA[iNdEx] - iNdEx++ - m.MinLevel |= (int32(b) & 0x7F) << shift - if b < 0x80 { - break - } - } - case 2: - if wireType != 0 { - return fmt.Errorf("proto: wrong wireType = %d for field MaxLevel", wireType) - } - m.MaxLevel = 0 - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return ErrIntOverflowConfig - } - if iNdEx >= l { - return io.ErrUnexpectedEOF - } - b := dAtA[iNdEx] - iNdEx++ - m.MaxLevel |= (int32(b) & 0x7F) << shift - if b < 0x80 { - break - } - } - case 3: - if wireType != 0 { - return fmt.Errorf("proto: wrong wireType = %d for field LevelMod", wireType) - } - m.LevelMod = 0 - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return ErrIntOverflowConfig - } - if iNdEx >= l { - return io.ErrUnexpectedEOF - } - b := dAtA[iNdEx] - iNdEx++ - m.LevelMod |= (int32(b) & 0x7F) << shift - if b < 0x80 { - break - } - } - case 4: - if wireType != 0 { - return fmt.Errorf("proto: wrong wireType = %d for field MaxCells", wireType) - } - m.MaxCells = 0 - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return ErrIntOverflowConfig - } - if iNdEx >= l { - return io.ErrUnexpectedEOF - } - b := dAtA[iNdEx] - iNdEx++ - m.MaxCells |= (int32(b) & 0x7F) << shift - if b < 0x80 { - break - } - } - default: - iNdEx = preIndex - skippy, err := skipConfig(dAtA[iNdEx:]) - if err != nil { - return err - } - if skippy < 0 { - return ErrInvalidLengthConfig - } - if (iNdEx + skippy) > l { - return io.ErrUnexpectedEOF - } - iNdEx += skippy - } - } - - if iNdEx > l { - return io.ErrUnexpectedEOF - } - return nil -} -func (m *S2GeographyConfig) Unmarshal(dAtA []byte) error { - l := len(dAtA) - iNdEx := 0 - for iNdEx < l { - preIndex := iNdEx - var wire uint64 - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return ErrIntOverflowConfig - } - if iNdEx >= l { - return io.ErrUnexpectedEOF - } - b := dAtA[iNdEx] - iNdEx++ - wire |= (uint64(b) & 0x7F) << shift - if b < 0x80 { - break - } - } - fieldNum := int32(wire >> 3) - wireType := int(wire & 0x7) - if wireType == 4 { - return fmt.Errorf("proto: S2GeographyConfig: wiretype end group for non-group") - } - if fieldNum <= 0 { - return fmt.Errorf("proto: S2GeographyConfig: illegal tag %d (wire type %d)", fieldNum, wire) - } - switch fieldNum { - case 1: - if wireType != 2 { - return fmt.Errorf("proto: wrong wireType = %d for field S2Config", wireType) - } - var msglen int - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return ErrIntOverflowConfig - } - if iNdEx >= l { - return io.ErrUnexpectedEOF - } - b := dAtA[iNdEx] - iNdEx++ - msglen |= (int(b) & 0x7F) << shift - if b < 0x80 { - break - } - } - if msglen < 0 { - return ErrInvalidLengthConfig - } - postIndex := iNdEx + msglen - if postIndex > l { - return io.ErrUnexpectedEOF - } - if m.S2Config == nil { - m.S2Config = &S2Config{} - } - if err := m.S2Config.Unmarshal(dAtA[iNdEx:postIndex]); err != nil { - return err - } - iNdEx = postIndex - default: - iNdEx = preIndex - skippy, err := skipConfig(dAtA[iNdEx:]) - if err != nil { - return err - } - if skippy < 0 { - return ErrInvalidLengthConfig - } - if (iNdEx + skippy) > l { - return io.ErrUnexpectedEOF - } - iNdEx += skippy - } - } - - if iNdEx > l { - return io.ErrUnexpectedEOF - } - return nil -} -func (m *S2GeometryConfig) Unmarshal(dAtA []byte) error { - l := len(dAtA) - iNdEx := 0 - for iNdEx < l { - preIndex := iNdEx - var wire uint64 - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return ErrIntOverflowConfig - } - if iNdEx >= l { - return io.ErrUnexpectedEOF - } - b := dAtA[iNdEx] - iNdEx++ - wire |= (uint64(b) & 0x7F) << shift - if b < 0x80 { - break - } - } - fieldNum := int32(wire >> 3) - wireType := int(wire & 0x7) - if wireType == 4 { - return fmt.Errorf("proto: S2GeometryConfig: wiretype end group for non-group") - } - if fieldNum <= 0 { - return fmt.Errorf("proto: S2GeometryConfig: illegal tag %d (wire type %d)", fieldNum, wire) - } - switch fieldNum { - case 1: - if wireType != 1 { - return fmt.Errorf("proto: wrong wireType = %d for field MinX", wireType) - } - var v uint64 - if (iNdEx + 8) > l { - return io.ErrUnexpectedEOF - } - v = uint64(encoding_binary.LittleEndian.Uint64(dAtA[iNdEx:])) - iNdEx += 8 - m.MinX = float64(math.Float64frombits(v)) - case 2: - if wireType != 1 { - return fmt.Errorf("proto: wrong wireType = %d for field MaxX", wireType) - } - var v uint64 - if (iNdEx + 8) > l { - return io.ErrUnexpectedEOF - } - v = uint64(encoding_binary.LittleEndian.Uint64(dAtA[iNdEx:])) - iNdEx += 8 - m.MaxX = float64(math.Float64frombits(v)) - case 3: - if wireType != 1 { - return fmt.Errorf("proto: wrong wireType = %d for field MinY", wireType) - } - var v uint64 - if (iNdEx + 8) > l { - return io.ErrUnexpectedEOF - } - v = uint64(encoding_binary.LittleEndian.Uint64(dAtA[iNdEx:])) - iNdEx += 8 - m.MinY = float64(math.Float64frombits(v)) - case 4: - if wireType != 1 { - return fmt.Errorf("proto: wrong wireType = %d for field MaxY", wireType) - } - var v uint64 - if (iNdEx + 8) > l { - return io.ErrUnexpectedEOF - } - v = uint64(encoding_binary.LittleEndian.Uint64(dAtA[iNdEx:])) - iNdEx += 8 - m.MaxY = float64(math.Float64frombits(v)) - case 5: - if wireType != 2 { - return fmt.Errorf("proto: wrong wireType = %d for field S2Config", wireType) - } - var msglen int - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return ErrIntOverflowConfig - } - if iNdEx >= l { - return io.ErrUnexpectedEOF - } - b := dAtA[iNdEx] - iNdEx++ - msglen |= (int(b) & 0x7F) << shift - if b < 0x80 { - break - } - } - if msglen < 0 { - return ErrInvalidLengthConfig - } - postIndex := iNdEx + msglen - if postIndex > l { - return io.ErrUnexpectedEOF - } - if m.S2Config == nil { - m.S2Config = &S2Config{} - } - if err := m.S2Config.Unmarshal(dAtA[iNdEx:postIndex]); err != nil { - return err - } - iNdEx = postIndex - default: - iNdEx = preIndex - skippy, err := skipConfig(dAtA[iNdEx:]) - if err != nil { - return err - } - if skippy < 0 { - return ErrInvalidLengthConfig - } - if (iNdEx + skippy) > l { - return io.ErrUnexpectedEOF - } - iNdEx += skippy - } - } - - if iNdEx > l { - return io.ErrUnexpectedEOF - } - return nil -} -func skipConfig(dAtA []byte) (n int, err error) { - l := len(dAtA) - iNdEx := 0 - for iNdEx < l { - var wire uint64 - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return 0, ErrIntOverflowConfig - } - if iNdEx >= l { - return 0, io.ErrUnexpectedEOF - } - b := dAtA[iNdEx] - iNdEx++ - wire |= (uint64(b) & 0x7F) << shift - if b < 0x80 { - break - } - } - wireType := int(wire & 0x7) - switch wireType { - case 0: - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return 0, ErrIntOverflowConfig - } - if iNdEx >= l { - return 0, io.ErrUnexpectedEOF - } - iNdEx++ - if dAtA[iNdEx-1] < 0x80 { - break - } - } - return iNdEx, nil - case 1: - iNdEx += 8 - return iNdEx, nil - case 2: - var length int - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return 0, ErrIntOverflowConfig - } - if iNdEx >= l { - return 0, io.ErrUnexpectedEOF - } - b := dAtA[iNdEx] - iNdEx++ - length |= (int(b) & 0x7F) << shift - if b < 0x80 { - break - } - } - iNdEx += length - if length < 0 { - return 0, ErrInvalidLengthConfig - } - return iNdEx, nil - case 3: - for { - var innerWire uint64 - var start int = iNdEx - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return 0, ErrIntOverflowConfig - } - if iNdEx >= l { - return 0, io.ErrUnexpectedEOF - } - b := dAtA[iNdEx] - iNdEx++ - innerWire |= (uint64(b) & 0x7F) << shift - if b < 0x80 { - break - } - } - innerWireType := int(innerWire & 0x7) - if innerWireType == 4 { - break - } - next, err := skipConfig(dAtA[start:]) - if err != nil { - return 0, err - } - iNdEx = start + next - } - return iNdEx, nil - case 4: - return iNdEx, nil - case 5: - iNdEx += 4 - return iNdEx, nil - default: - return 0, fmt.Errorf("proto: illegal wireType %d", wireType) - } - } - panic("unreachable") -} - -var ( - ErrInvalidLengthConfig = fmt.Errorf("proto: negative length found during unmarshaling") - ErrIntOverflowConfig = fmt.Errorf("proto: integer overflow") -) - -func init() { proto.RegisterFile("geo/geoindex/config.proto", fileDescriptor_config_4fdfa32e25381f1e) } - -var fileDescriptor_config_4fdfa32e25381f1e = []byte{ - // 376 bytes of a gzipped FileDescriptorProto - 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x9c, 0x92, 0xbf, 0x6e, 0xea, 0x30, - 0x18, 0xc5, 0x63, 0x2e, 0xa0, 0x60, 0xee, 0x70, 0x6f, 0xee, 0x55, 0x95, 0xb6, 0x92, 0x41, 0x4c, - 0xb4, 0x43, 0x90, 0xd2, 0x0d, 0xa9, 0x4b, 0x19, 0xaa, 0x4a, 0x74, 0x81, 0x05, 0xba, 0x44, 0x69, - 0x70, 0x4d, 0xd4, 0x24, 0x46, 0x09, 0xaa, 0x9c, 0xbd, 0x0f, 0xd0, 0x47, 0x60, 0xe7, 0x45, 0x18, - 0x19, 0x19, 0xdb, 0xb0, 0xf4, 0x31, 0xaa, 0x7c, 0x76, 0xa8, 0xfa, 0x77, 0xe8, 0x66, 0x7f, 0xe7, - 0x7c, 0x3f, 0x9d, 0xe3, 0x04, 0xef, 0x33, 0xca, 0x3b, 0x8c, 0x72, 0x3f, 0x9a, 0x50, 0xd1, 0xf1, - 0x78, 0x74, 0xe3, 0x33, 0x6b, 0x16, 0xf3, 0x39, 0x37, 0xf6, 0x3c, 0xee, 0xdd, 0xc6, 0xdc, 0xf5, - 0xa6, 0x16, 0xa3, 0xdc, 0x2a, 0x4c, 0x07, 0xff, 0x19, 0x67, 0x1c, 0x2c, 0x9d, 0xfc, 0x24, 0xdd, - 0xad, 0x25, 0xc2, 0xd5, 0x1e, 0xac, 0x1b, 0x7d, 0xfc, 0x3b, 0xb1, 0x1d, 0x46, 0x39, 0x8b, 0xdd, - 0xd9, 0x34, 0x35, 0x51, 0x13, 0xb5, 0xeb, 0xf6, 0x91, 0xf5, 0x39, 0xcf, 0x1a, 0xda, 0xe7, 0x85, - 0x55, 0x02, 0x06, 0xf5, 0xe4, 0x75, 0x64, 0x5c, 0xe0, 0xba, 0xa4, 0x85, 0x74, 0x1e, 0xa7, 0x66, - 0x09, 0x60, 0xed, 0x6f, 0x61, 0xe0, 0x54, 0x2c, 0x9c, 0xec, 0x26, 0x5d, 0x7d, 0xb5, 0x68, 0xa0, - 0xe7, 0x45, 0x03, 0xb5, 0xee, 0x11, 0xd6, 0x87, 0xb6, 0xca, 0x7b, 0x88, 0x6b, 0xa1, 0x1f, 0x39, - 0x01, 0xbd, 0xa3, 0x01, 0x84, 0xad, 0x0c, 0xf4, 0xd0, 0x8f, 0xfa, 0xf9, 0x1d, 0x44, 0x57, 0x28, - 0xb1, 0xa4, 0x44, 0x57, 0xec, 0x44, 0x10, 0x9c, 0x90, 0x4f, 0xcc, 0x5f, 0x52, 0x84, 0xc1, 0x25, - 0x9f, 0x14, 0x9b, 0x1e, 0x0d, 0x82, 0xc4, 0x2c, 0xef, 0x36, 0x7b, 0xf9, 0xbd, 0x5b, 0x86, 0x18, - 0x23, 0xfc, 0xf7, 0x43, 0x7b, 0xe3, 0x14, 0xd7, 0x12, 0xdb, 0x91, 0x9f, 0x42, 0xbd, 0x5d, 0xf3, - 0xeb, 0xba, 0xaa, 0xa6, 0x9e, 0xa8, 0x93, 0x22, 0x2f, 0x11, 0xfe, 0xf3, 0xfe, 0x2d, 0x8c, 0x7f, - 0xb8, 0x92, 0x17, 0x15, 0x40, 0x45, 0x83, 0x72, 0xe8, 0x47, 0x23, 0x18, 0xba, 0xc2, 0x11, 0x50, - 0x2e, 0x1f, 0xba, 0x62, 0x54, 0x38, 0x53, 0x28, 0x25, 0x9d, 0xe3, 0xc2, 0x99, 0x42, 0x19, 0xe9, - 0x1c, 0xbf, 0x4d, 0x5b, 0xf9, 0x59, 0xda, 0xb3, 0xe3, 0xd5, 0x13, 0xd1, 0x56, 0x19, 0x41, 0xeb, - 0x8c, 0xa0, 0x4d, 0x46, 0xd0, 0x63, 0x46, 0xd0, 0xc3, 0x96, 0x68, 0xeb, 0x2d, 0xd1, 0x36, 0x5b, - 0xa2, 0x5d, 0xe9, 0x05, 0xe4, 0xba, 0x0a, 0xff, 0xdb, 0xc9, 0x4b, 0x00, 0x00, 0x00, 0xff, 0xff, - 0xb1, 0xfa, 0x92, 0x6c, 0xba, 0x02, 0x00, 0x00, -} diff --git a/postgres/parser/geo/geoindex/config.proto b/postgres/parser/geo/geoindex/config.proto deleted file mode 100644 index 1e8a77685c..0000000000 --- a/postgres/parser/geo/geoindex/config.proto +++ /dev/null @@ -1,79 +0,0 @@ -// Copyright 2023 Dolthub, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Copyright 2020 The Cockroach Authors. -// -// Use of this software is governed by the Business Source License -// included in the file licenses/BSL.txt. -// -// As of the Change Date specified in that file, in accordance with -// the Business Source License, use of this software will be governed -// by the Apache License, Version 2.0, included in the file -// licenses/APL.txt. - -syntax = "proto3"; -package cockroach.geo.geoindex; -option go_package = "geoindex"; - -import "gogoproto/gogo.proto"; - -// Config is the information used to tune one instance of a geospatial index. -// Each SQL index will have its own config. -// -// At the moment, only one major indexing strategy is implemented (S2 cells). -message Config { - option (gogoproto.equal) = true; - option (gogoproto.onlyone) = true; - S2GeographyConfig s2_geography = 1; - S2GeometryConfig s2_geometry = 2; -} - -// S2Config is the required information to tune one instance of an S2 cell -// backed geospatial index. For advanced users only -- the defaults should be -// good enough. -// -// TODO(sumeer): Based on experiments, reduce the knobs below by making the -// covering self-tuning. -message S2Config { - option (gogoproto.equal) = true; - // MinLevel is the minimum cell level stored in the index. If left unset, it - // defaults to 0. - int32 min_level = 1; - // MaxLevel is the maximum cell level stored in the index. If left unset, it - // defaults to 30. - int32 max_level = 2; - // `MaxLevel-MinLevel` must be an exact multiple of LevelMod. If left unset, - // it defaults to 1. - int32 level_mod = 3; - // MaxCells is a soft hint for the maximum number of entries used to store a - // single geospatial object. If left unset, it defaults to 4. - int32 max_cells = 4; -} - -message S2GeographyConfig { - option (gogoproto.equal) = true; - S2Config s2_config = 1; -} - -message S2GeometryConfig { - option (gogoproto.equal) = true; - // The rectangle bounds of the plane that will be efficiently indexed. Shapes - // should rarely exceed these bounds. - double min_x = 1; - double max_x = 2; - double min_y = 3; - double max_y = 4; - - S2Config s2_config = 5; -} diff --git a/postgres/parser/geo/geoindex/geoindex.go b/postgres/parser/geo/geoindex/geoindex.go deleted file mode 100644 index e51712c892..0000000000 --- a/postgres/parser/geo/geoindex/geoindex.go +++ /dev/null @@ -1,717 +0,0 @@ -// Copyright 2023 Dolthub, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Copyright 2020 The Cockroach Authors. -// -// Use of this software is governed by the Business Source License -// included in the file licenses/BSL.txt. -// -// As of the Change Date specified in that file, in accordance with -// the Business Source License, use of this software will be governed -// by the Apache License, Version 2.0, included in the file -// licenses/APL.txt. - -package geoindex - -import ( - "context" - "fmt" - "math" - "sort" - "strings" - - "github.com/golang/geo/s2" - - "github.com/dolthub/doltgresql/postgres/parser/geo" - "github.com/dolthub/doltgresql/postgres/parser/geo/geogfn" -) - -// RelationshipMap contains all the geospatial functions that can be index- -// accelerated. Each function implies a certain type of geospatial relationship, -// which affects how the index is queried as part of a constrained scan or -// geospatial lookup join. RelationshipMap maps the function name to its -// corresponding relationship (Covers, CoveredBy, DFullyWithin, DWithin or Intersects). -// -// Note that for all of these functions, a geospatial lookup join or constrained -// index scan may produce false positives. Therefore, the original function must -// be called on the output of the index operation to filter the results. -var RelationshipMap = map[string]RelationshipType{ - "st_covers": Covers, - "st_coveredby": CoveredBy, - "st_contains": Covers, - "st_containsproperly": Covers, - "st_crosses": Intersects, - "st_dwithin": DWithin, - "st_dfullywithin": DFullyWithin, - "st_equals": Intersects, - "st_intersects": Intersects, - "st_overlaps": Intersects, - "st_touches": Intersects, - "st_within": CoveredBy, - "st_dwithinexclusive": DWithin, - "st_dfullywithinexclusive": DFullyWithin, -} - -// RelationshipReverseMap contains a default function for each of the -// possible geospatial relationships. -var RelationshipReverseMap = map[RelationshipType]string{ - Covers: "st_covers", - CoveredBy: "st_coveredby", - DWithin: "st_dwithin", - DFullyWithin: "st_dfullywithin", - Intersects: "st_intersects", -} - -// CommuteRelationshipMap is used to determine how the geospatial -// relationship changes if the arguments to the index-accelerated function are -// commuted. -// -// The relationships in the RelationshipMap map above only apply when the -// second argument to the function is the indexed column. If the arguments are -// commuted so that the first argument is the indexed column, the relationship -// may change. -var CommuteRelationshipMap = map[RelationshipType]RelationshipType{ - Covers: CoveredBy, - CoveredBy: Covers, - DWithin: DWithin, - DFullyWithin: DFullyWithin, - Intersects: Intersects, -} - -// Interfaces for accelerating certain spatial operations, by allowing the -// caller to use an externally stored index. -// -// An index interface has methods that specify what to write or read from the -// stored index. It is the caller's responsibility to do the actual writes and -// reads. Index keys are represented using uint64 which implies that the index -// is transforming the 2D space (spherical or planar) to 1D space, say using a -// space-filling curve. The index can return false positives so query -// execution must use an exact filtering step after using the index. The -// current implementations use the S2 geometry library. -// -// The externally stored index must support key-value semantics and range -// queries over keys. The return values from Covers/CoveredBy/Intersects are -// specialized to the computation that needs to be performed: -// - Intersects needs to union (a) all the ranges corresponding to subtrees -// rooted at the covering of g, (b) all the parent nodes of the covering -// of g. The individual entries in (b) are representable as ranges of -// length 1. All of this is represented as UnionKeySpans. Covers, which -// is the shape that g covers, currently delegates to Interects so -// returns the same. -// -// - CoveredBy, which are the shapes that g is covered-by, needs to compute -// on the paths from each covering node to the root. For example, consider -// a quad-tree approach to dividing the space, where the nodes/cells -// covering g are c53, c61, c64. The corresponding ancestor path sets are -// {c53, c13, c3, c0}, {c61, c15, c3, c0}, {c64, c15, c3, c0}. Let I(c53) -// represent the index entries for c53. Then, -// I(c53) \union I(c13) \union I(c3) \union I(c0) -// represents all the shapes that cover cell c53. Similar union expressions -// can be constructed for all these paths. The computation needs to -// intersect these unions since all these cells need to be covered by a -// shape that covers g. One can extract the common sub-expressions to give -// I(c0) \union I(c3) \union -// ((I(c13) \union I(c53)) \intersection -// (I(c15) \union (I(c61) \intersection I(c64))) -// CoveredBy returns this factored expression in Reverse Polish notation. - -// GeographyIndex is an index over the unit sphere. -type GeographyIndex interface { - // InvertedIndexKeys returns the keys to store this object under when adding - // it to the index. - InvertedIndexKeys(c context.Context, g geo.Geography) ([]Key, error) - - // Acceleration for topological relationships (see - // https://postgis.net/docs/reference.html#Spatial_Relationships). Distance - // relationships can be accelerated by adjusting g before calling these - // functions. Bounding box operators are not accelerated since we do not - // index the bounding box -- bounding box queries are an implementation - // detail of a particular indexing approach in PostGIS and are not part of - // the OGC or SQL/MM specs. - - // Covers returns the index spans to read and union for the relationship - // ST_Covers(g, x), where x are the indexed geometries. - Covers(c context.Context, g geo.Geography) (UnionKeySpans, error) - - // CoveredBy returns the index entries to read and the expression to compute - // for ST_CoveredBy(g, x), where x are the indexed geometries. - CoveredBy(c context.Context, g geo.Geography) (RPKeyExpr, error) - - // Intersects returns the index spans to read and union for the relationship - // ST_Intersects(g, x), where x are the indexed geometries. - Intersects(c context.Context, g geo.Geography) (UnionKeySpans, error) - - // DWithin returns the index spans to read and union for the relationship - // ST_DWithin(g, x, distanceMeters). That is, there exists a part of - // geometry g that is within distanceMeters of x, where x is an indexed - // geometry. This function assumes a sphere. - DWithin( - c context.Context, g geo.Geography, distanceMeters float64, - useSphereOrSpheroid geogfn.UseSphereOrSpheroid, - ) (UnionKeySpans, error) - - // TestingInnerCovering returns an inner covering of g. - TestingInnerCovering(g geo.Geography) s2.CellUnion -} - -// GeometryIndex is an index over 2D cartesian coordinates. -type GeometryIndex interface { - // InvertedIndexKeys returns the keys to store this object under when adding - // it to the index. - InvertedIndexKeys(c context.Context, g geo.Geometry) ([]Key, error) - - // Acceleration for topological relationships (see - // https://postgis.net/docs/reference.html#Spatial_Relationships). Distance - // relationships can be accelerated by adjusting g before calling these - // functions. Bounding box operators are not accelerated since we do not - // index the bounding box -- bounding box queries are an implementation - // detail of a particular indexing approach in PostGIS and are not part of - // the OGC or SQL/MM specs. - - // Covers returns the index spans to read and union for the relationship - // ST_Covers(g, x), where x are the indexed geometries. - Covers(c context.Context, g geo.Geometry) (UnionKeySpans, error) - - // CoveredBy returns the index entries to read and the expression to compute - // for ST_CoveredBy(g, x), where x are the indexed geometries. - CoveredBy(c context.Context, g geo.Geometry) (RPKeyExpr, error) - - // Intersects returns the index spans to read and union for the relationship - // ST_Intersects(g, x), where x are the indexed geometries. - Intersects(c context.Context, g geo.Geometry) (UnionKeySpans, error) - - // DWithin returns the index spans to read and union for the relationship - // ST_DWithin(g, x, distance). That is, there exists a part of geometry g - // that is within distance units of x, where x is an indexed geometry. - DWithin(c context.Context, g geo.Geometry, distance float64) (UnionKeySpans, error) - - // DFullyWithin returns the index spans to read and union for the - // relationship ST_DFullyWithin(g, x, distance). That is, the maximum distance - // across every pair of points comprising geometries g and x is within distance - // units, where x is an indexed geometry. - DFullyWithin(c context.Context, g geo.Geometry, distance float64) (UnionKeySpans, error) - - // TestingInnerCovering returns an inner covering of g. - TestingInnerCovering(g geo.Geometry) s2.CellUnion -} - -// RelationshipType stores a type of geospatial relationship query that can -// be accelerated using an index. -type RelationshipType uint8 - -const ( - // Covers corresponds to the relationship in which one geospatial object - // covers another geospatial object. - Covers RelationshipType = (1 << iota) - - // CoveredBy corresponds to the relationship in which one geospatial object - // is covered by another geospatial object. - CoveredBy - - // Intersects corresponds to the relationship in which one geospatial object - // intersects another geospatial object. - Intersects - - // DWithin corresponds to a relationship where there exists a part of one - // geometry within d distance units of the other geometry. - DWithin - - // DFullyWithin corresponds to a relationship where every pair of points in - // two geometries are within d distance units. - DFullyWithin -) - -var geoRelationshipTypeStr = map[RelationshipType]string{ - Covers: "covers", - CoveredBy: "covered by", - Intersects: "intersects", -} - -func (gr RelationshipType) String() string { - return geoRelationshipTypeStr[gr] -} - -// IsEmptyConfig returns whether the given config contains a geospatial index -// configuration. -func IsEmptyConfig(cfg *Config) bool { - if cfg == nil { - return true - } - return cfg.S2Geography == nil && cfg.S2Geometry == nil -} - -// IsGeographyConfig returns whether the config is a geography geospatial -// index configuration. -func IsGeographyConfig(cfg *Config) bool { - if cfg == nil { - return false - } - return cfg.S2Geography != nil -} - -// IsGeometryConfig returns whether the config is a geometry geospatial -// index configuration. -func IsGeometryConfig(cfg *Config) bool { - if cfg == nil { - return false - } - return cfg.S2Geometry != nil -} - -// Key is one entry under which a geospatial shape is stored on behalf of an -// Index. The index is of the form (Key, Primary Key). -type Key uint64 - -// rpExprElement implements the RPExprElement interface. -func (k Key) rpExprElement() {} - -func (k Key) String() string { - c := s2.CellID(k) - if !c.IsValid() { - return "spilled" - } - var b strings.Builder - b.WriteByte('F') - b.WriteByte("012345"[c.Face()]) - fmt.Fprintf(&b, "/L%d/", c.Level()) - for level := 1; level <= c.Level(); level++ { - b.WriteByte("0123"[c.ChildPosition(level)]) - } - return b.String() -} - -// KeySpan represents a range of Keys. -type KeySpan struct { - // Both Start and End are inclusive, i.e., [Start, End]. - Start, End Key -} - -// UnionKeySpans is the set of indexed spans to retrieve and combine via set -// union. The spans are guaranteed to be non-overlapping and sorted in -// increasing order. Duplicate primary keys will not be retrieved by any -// individual key, but they may be present if more than one key is retrieved -// (including duplicates in a single span where End - Start > 1). -type UnionKeySpans []KeySpan - -func (s UnionKeySpans) String() string { - return s.toString(math.MaxInt32) -} - -func (s UnionKeySpans) toString(wrap int) string { - b := newStringBuilderWithWrap(&strings.Builder{}, wrap) - for i, span := range s { - if span.Start == span.End { - fmt.Fprintf(b, "%s", span.Start) - } else { - fmt.Fprintf(b, "[%s, %s]", span.Start, span.End) - } - if i != len(s)-1 { - b.WriteString(", ") - } - b.tryWrap() - } - return b.String() -} - -// RPExprElement is an element in the Reverse Polish notation expression. -// It is implemented by Key and RPSetOperator. -type RPExprElement interface { - rpExprElement() -} - -// RPSetOperator is a set operator in the Reverse Polish notation expression. -type RPSetOperator int - -const ( - // RPSetUnion is the union operator. - RPSetUnion RPSetOperator = iota + 1 - - // RPSetIntersection is the intersection operator. - RPSetIntersection -) - -// rpExprElement implements the RPExprElement interface. -func (o RPSetOperator) rpExprElement() {} - -// RPKeyExpr is an expression to evaluate over primary keys retrieved for -// index keys. If we view each index key as a posting list of primary keys, -// the expression involves union and intersection over the sets represented by -// each posting list. For S2, this expression represents an intersection of -// ancestors of different keys (cell ids) and is likely to contain many common -// keys. This special structure allows us to efficiently and easily eliminate -// common sub-expressions, hence the interface presents the factored -// expression. The expression is represented in Reverse Polish notation. -type RPKeyExpr []RPExprElement - -func (x RPKeyExpr) String() string { - var elements []string - for _, e := range x { - switch elem := e.(type) { - case Key: - elements = append(elements, elem.String()) - case RPSetOperator: - switch elem { - case RPSetUnion: - elements = append(elements, `\U`) - case RPSetIntersection: - elements = append(elements, `\I`) - } - } - } - return strings.Join(elements, " ") -} - -// Helper functions for index implementations that use the S2 geometry -// library. - -// covererInterface provides a covering for a set of regions. -type covererInterface interface { - // covering returns a normalized CellUnion, i.e., it is sorted, and does not - // contain redundancy. - covering(regions []s2.Region) s2.CellUnion -} - -// simpleCovererImpl is an implementation of covererInterface that delegates -// to s2.RegionCoverer. -type simpleCovererImpl struct { - rc *s2.RegionCoverer -} - -var _ covererInterface = simpleCovererImpl{} - -func (rc simpleCovererImpl) covering(regions []s2.Region) s2.CellUnion { - // TODO(sumeer): Add a max cells constraint for the whole covering, - // to respect the index configuration. - var u s2.CellUnion - for _, r := range regions { - u = append(u, rc.rc.Covering(r)...) - } - // Ensure the cells are non-overlapping. - u.Normalize() - return u -} - -// The "inner covering", for shape s, represented by the regions parameter, is -// used to find shapes that contain shape s. A regular covering that includes -// a cell c not completely covered by shape s could result in false negatives, -// since shape x that covers shape s could use a finer cell covering (using -// cells below c). For example, consider a portion of the cell quad-tree -// below: -// -// c0 -// | -// c3 -// | -// +---+---+ -// | | -// c13 c15 -// | | -// c53 +--+--+ -// | | -// c61 c64 -// -// Shape s could have a regular covering c15, c53, where c15 has 4 child cells -// c61..c64, and shape s only intersects wit c61, c64. A different shape x -// that covers shape s may have a covering c61, c64, c53. That is, it has used -// the finer cells c61, c64. If we used both regular coverings it is hard to -// know that x covers g. Hence, we compute the "inner covering" of g (defined -// below). -// -// The interior covering of shape s includes only cells covered by s. This is -// computed by RegionCoverer.InteriorCovering() and is intuitively what we -// need. But the interior covering is naturally empty for points and lines -// (and can be empty for polygons too), and an empty covering is not useful -// for constraining index lookups. We observe that leaf cells that intersect -// shape s can be used in the covering, since the covering of shape x must -// also cover these cells. This allows us to compute non-empty coverings for -// all shapes. Since this is not technically the interior covering, we use the -// term "inner covering". -func innerCovering(rc *s2.RegionCoverer, regions []s2.Region) s2.CellUnion { - var u s2.CellUnion - for _, r := range regions { - switch region := r.(type) { - case s2.Point: - cellID := cellIDCoveringPoint(region, rc.MaxLevel) - u = append(u, cellID) - case *s2.Polyline: - // TODO(sumeer): for long lines could also pick some intermediate - // points along the line. Decide based on experiments. - for _, p := range *region { - cellID := cellIDCoveringPoint(p, rc.MaxLevel) - u = append(u, cellID) - } - case *s2.Polygon: - // Iterate over all exterior points - if region.NumLoops() > 0 { - loop := region.Loop(0) - for _, p := range loop.Vertices() { - cellID := cellIDCoveringPoint(p, rc.MaxLevel) - u = append(u, cellID) - } - // Arbitrary threshold value. This is to avoid computing an expensive - // region covering for regions with small area. - // TODO(sumeer): Improve this heuristic: - // - Area() may be expensive. - // - For large area regions, put an upper bound on the - // level used for cells. - // Decide based on experiments. - const smallPolygonLevelThreshold = 25 - if region.Area() > s2.AvgAreaMetric.Value(smallPolygonLevelThreshold) { - u = append(u, rc.InteriorCovering(region)...) - } - } - default: - panic("bug: code should not be producing unhandled Region type") - } - } - // Ensure the cells are non-overlapping. - u.Normalize() - - // TODO(sumeer): if the number of cells is too many, make the list sparser. - // u[len(u)-1] - u[0] / len(u) is the mean distance between cells. Pick a - // target distance based on the goal to reduce to k cells: target_distance - // := mean_distance * k / len(u) Then iterate over u and for every sequence - // of cells that are within target_distance, replace by median cell or by - // largest cell. Decide based on experiments. - - return u -} - -func cellIDCoveringPoint(point s2.Point, level int) s2.CellID { - cellID := s2.CellFromPoint(point).ID() - if !cellID.IsLeaf() { - panic("bug in S2") - } - return cellID.Parent(level) -} - -// ancestorCells returns the set of cells containing these cells, not -// including the given cells. -// -// TODO(sumeer): use the MinLevel and LevelMod of the RegionCoverer used -// for the index to constrain the ancestors set. -func ancestorCells(cells []s2.CellID) []s2.CellID { - var ancestors []s2.CellID - var seen map[s2.CellID]struct{} - if len(cells) > 1 { - seen = make(map[s2.CellID]struct{}) - } - for _, c := range cells { - for l := c.Level() - 1; l >= 0; l-- { - p := c.Parent(l) - if seen != nil { - if _, ok := seen[p]; ok { - break - } - seen[p] = struct{}{} - } - ancestors = append(ancestors, p) - } - } - return ancestors -} - -// Helper for InvertedIndexKeys. -func invertedIndexKeys(_ context.Context, rc covererInterface, r []s2.Region) []Key { - covering := rc.covering(r) - keys := make([]Key, len(covering)) - for i, cid := range covering { - keys[i] = Key(cid) - } - return keys -} - -// TODO(sumeer): examine RegionCoverer carefully to see if we can strengthen -// the covering invariant, which would increase the efficiency of covers() and -// remove the need for TestingInnerCovering(). -// -// Helper for Covers. -func covers(c context.Context, rc covererInterface, r []s2.Region) UnionKeySpans { - // We use intersects since geometries covered by r may have been indexed - // using cells that are ancestors of the covering of r. We could avoid - // reading ancestors if we had a stronger covering invariant, such as by - // indexing inner coverings. - return intersects(c, rc, r) -} - -// Helper for Intersects. Returns spans in sorted order for convenience of -// scans. -func intersects(_ context.Context, rc covererInterface, r []s2.Region) UnionKeySpans { - covering := rc.covering(r) - return intersectsUsingCovering(covering) -} - -func intersectsUsingCovering(covering s2.CellUnion) UnionKeySpans { - querySpans := make([]KeySpan, len(covering)) - for i, cid := range covering { - querySpans[i] = KeySpan{Start: Key(cid.RangeMin()), End: Key(cid.RangeMax())} - } - for _, cid := range ancestorCells(covering) { - querySpans = append(querySpans, KeySpan{Start: Key(cid), End: Key(cid)}) - } - sort.Slice(querySpans, func(i, j int) bool { return querySpans[i].Start < querySpans[j].Start }) - return querySpans -} - -// Helper for CoveredBy. -func coveredBy(_ context.Context, rc *s2.RegionCoverer, r []s2.Region) RPKeyExpr { - covering := innerCovering(rc, r) - ancestors := ancestorCells(covering) - - // The covering is normalized so no 2 cells are such that one is an ancestor - // of another. Arrange these cells and their ancestors in a quad-tree. Any cell - // with more than one child needs to be unioned with the intersection of the - // expressions corresponding to each child. See the detailed comment in - // generateRPExprForTree(). - - // It is sufficient to represent the tree(s) using presentCells since the ids - // of all possible 4 children of a cell can be computed and checked for - // presence in the map. - presentCells := make(map[s2.CellID]struct{}, len(covering)+len(ancestors)) - for _, c := range covering { - presentCells[c] = struct{}{} - } - for _, c := range ancestors { - presentCells[c] = struct{}{} - } - - // Construct the reverse polish expression. Note that there are up to 6 - // trees corresponding to the 6 faces in S2. The expressions for the - // trees need to be intersected with each other. - expr := make([]RPExprElement, 0, len(presentCells)*2) - numFaces := 0 - for face := 0; face < 6; face++ { - rootID := s2.CellIDFromFace(face) - if _, ok := presentCells[rootID]; !ok { - continue - } - expr = generateRPExprForTree(rootID, presentCells, expr) - numFaces++ - if numFaces > 1 { - expr = append(expr, RPSetIntersection) - } - } - return expr -} - -// The quad-trees stored in presentCells together represent a set expression. -// This expression specifies: -// - the path for each leaf to the root of that quad-tree. The index entries -// on each such path represent the shapes that cover that leaf. Hence these -// index entries for a single path need to be unioned to give the shapes -// that cover the leaf. -// - The full expression specifies the shapes that cover all the leaves, so -// the union expressions for the paths must be intersected with each other. -// -// Reusing an example from earlier in this file, say the quad-tree is: -// c0 -// | -// c3 -// | -// +---+---+ -// | | -// c13 c15 -// | | -// c53 +--+--+ -// | | -// c61 c64 -// -// This tree represents the following expression (where I(c) are the index -// entries stored at cell c): -// (I(c64) \union I(c15) \union I(c3) \union I(c0)) \intersection -// (I(c61) \union I(c15) \union I(c3) \union I(c0)) \intersection -// (I(c53) \union I(c13) \union I(c3) \union I(c0)) -// In this example all the union sub-expressions have the same number of terms -// but that does not need to be true. -// -// The above expression can be factored to eliminate repetition of the -// same cell. The factored expression for this example is: -// I(c0) \union I(c3) \union -// ((I(c13) \union I(c53)) \intersection -// (I(c15) \union (I(c61) \intersection I(c64))) -// -// This function generates this factored expression represented in reverse -// polish notation. -// -// One can generate the factored expression in reverse polish notation using -// a post-order traversal of this tree: -// Step A. append the expression for the subtree rooted at c3 -// Step B. append c0 and the union operator -// For Step A: -// - append the expression for the subtree rooted at c13 -// - append the expression for the subtree rooted at c15 -// - append the intersection operator -// - append c13 -// - append the union operator -func generateRPExprForTree( - rootID s2.CellID, presentCells map[s2.CellID]struct{}, expr []RPExprElement, -) []RPExprElement { - expr = append(expr, Key(rootID)) - if rootID.IsLeaf() { - return expr - } - numChildren := 0 - for _, childCellID := range rootID.Children() { - if _, ok := presentCells[childCellID]; !ok { - continue - } - expr = generateRPExprForTree(childCellID, presentCells, expr) - numChildren++ - if numChildren > 1 { - expr = append(expr, RPSetIntersection) - } - } - if numChildren > 0 { - expr = append(expr, RPSetUnion) - } - return expr -} - -// stringBuilderWithWrap is a strings.Builder that approximately wraps at a -// certain number of characters. Newline characters should only be -// written using tryWrap and doWrap. -type stringBuilderWithWrap struct { - *strings.Builder - wrap int - lastWrap int -} - -func newStringBuilderWithWrap(b *strings.Builder, wrap int) *stringBuilderWithWrap { - return &stringBuilderWithWrap{ - Builder: b, - wrap: wrap, - lastWrap: b.Len(), - } -} - -func (b *stringBuilderWithWrap) tryWrap() { - if b.Len()-b.lastWrap > b.wrap { - b.doWrap() - } -} - -func (b *stringBuilderWithWrap) doWrap() { - fmt.Fprintln(b) - b.lastWrap = b.Len() -} - -// DefaultS2Config returns the default S2Config to initialize. -func DefaultS2Config() *S2Config { - return &S2Config{ - MinLevel: 0, - MaxLevel: 30, - LevelMod: 1, - MaxCells: 4, - } -} diff --git a/postgres/parser/geo/geoindex/s2_geography_index.go b/postgres/parser/geo/geoindex/s2_geography_index.go deleted file mode 100644 index cb1b10ec3b..0000000000 --- a/postgres/parser/geo/geoindex/s2_geography_index.go +++ /dev/null @@ -1,219 +0,0 @@ -// Copyright 2023 Dolthub, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Copyright 2020 The Cockroach Authors. -// -// Use of this software is governed by the Business Source License -// included in the file licenses/BSL.txt. -// -// As of the Change Date specified in that file, in accordance with -// the Business Source License, use of this software will be governed -// by the Apache License, Version 2.0, included in the file -// licenses/APL.txt. - -package geoindex - -import ( - "context" - - "github.com/cockroachdb/errors" - "github.com/golang/geo/s1" - "github.com/golang/geo/s2" - "github.com/twpayne/go-geom" - - "github.com/dolthub/doltgresql/postgres/parser/geo" - "github.com/dolthub/doltgresql/postgres/parser/geo/geogfn" - "github.com/dolthub/doltgresql/postgres/parser/geo/geoprojbase" -) - -// s2GeographyIndex is an implementation of GeographyIndex that uses the S2 geometry -// library. -type s2GeographyIndex struct { - rc *s2.RegionCoverer -} - -var _ GeographyIndex = (*s2GeographyIndex)(nil) - -// NewS2GeographyIndex returns an index with the given configuration. The -// configuration of an index cannot be changed without rewriting the index -// since deletes could miss some index entries. Currently, reads could use a -// different configuration, but that is subject to change if we manage to -// strengthen the covering invariants (see the todo in covers() in index.go). -func NewS2GeographyIndex(cfg S2GeographyConfig) GeographyIndex { - // TODO(sumeer): Sanity check cfg. - return &s2GeographyIndex{ - rc: &s2.RegionCoverer{ - MinLevel: int(cfg.S2Config.MinLevel), - MaxLevel: int(cfg.S2Config.MaxLevel), - LevelMod: int(cfg.S2Config.LevelMod), - MaxCells: int(cfg.S2Config.MaxCells), - }, - } -} - -// DefaultGeographyIndexConfig returns a default config for a geography index. -func DefaultGeographyIndexConfig() *Config { - return &Config{ - S2Geography: &S2GeographyConfig{S2Config: DefaultS2Config()}, - } -} - -// geogCovererWithBBoxFallback first computes the covering for the provided -// regions (which were computed using g), and if the covering is too broad -// (contains top-level cells from all faces), falls back to using the bounding -// box of g to compute the covering. -type geogCovererWithBBoxFallback struct { - rc *s2.RegionCoverer - g geo.Geography -} - -var _ covererInterface = geogCovererWithBBoxFallback{} - -func toDeg(radians float64) float64 { - return s1.Angle(radians).Degrees() -} - -func (rc geogCovererWithBBoxFallback) covering(regions []s2.Region) s2.CellUnion { - cu := simpleCovererImpl{rc: rc.rc}.covering(regions) - if isBadGeogCovering(cu) { - bbox := rc.g.SpatialObject().BoundingBox - if bbox == nil { - return cu - } - flatCoords := []float64{ - toDeg(bbox.LoX), toDeg(bbox.LoY), toDeg(bbox.HiX), toDeg(bbox.LoY), - toDeg(bbox.HiX), toDeg(bbox.HiY), toDeg(bbox.LoX), toDeg(bbox.HiY), - toDeg(bbox.LoX), toDeg(bbox.LoY)} - bboxT := geom.NewPolygonFlat(geom.XY, flatCoords, []int{len(flatCoords)}) - bboxRegions, err := geo.S2RegionsFromGeomT(bboxT, geo.EmptyBehaviorOmit) - if err != nil { - return cu - } - bboxCU := simpleCovererImpl{rc: rc.rc}.covering(bboxRegions) - if !isBadGeogCovering(bboxCU) { - cu = bboxCU - } - } - return cu -} - -func isBadGeogCovering(cu s2.CellUnion) bool { - const numFaces = 6 - if len(cu) != numFaces { - return false - } - numFaceCells := 0 - for _, c := range cu { - if c.Level() == 0 { - numFaceCells++ - } - } - return numFaces == numFaceCells -} - -// InvertedIndexKeys implements the GeographyIndex interface. -func (i *s2GeographyIndex) InvertedIndexKeys(c context.Context, g geo.Geography) ([]Key, error) { - r, err := g.AsS2(geo.EmptyBehaviorOmit) - if err != nil { - return nil, err - } - return invertedIndexKeys(c, geogCovererWithBBoxFallback{rc: i.rc, g: g}, r), nil -} - -// Covers implements the GeographyIndex interface. -func (i *s2GeographyIndex) Covers(c context.Context, g geo.Geography) (UnionKeySpans, error) { - r, err := g.AsS2(geo.EmptyBehaviorOmit) - if err != nil { - return nil, err - } - return covers(c, geogCovererWithBBoxFallback{rc: i.rc, g: g}, r), nil -} - -// CoveredBy implements the GeographyIndex interface. -func (i *s2GeographyIndex) CoveredBy(c context.Context, g geo.Geography) (RPKeyExpr, error) { - r, err := g.AsS2(geo.EmptyBehaviorOmit) - if err != nil { - return nil, err - } - return coveredBy(c, i.rc, r), nil -} - -// Intersects implements the GeographyIndex interface. -func (i *s2GeographyIndex) Intersects(c context.Context, g geo.Geography) (UnionKeySpans, error) { - r, err := g.AsS2(geo.EmptyBehaviorOmit) - if err != nil { - return nil, err - } - return intersects(c, geogCovererWithBBoxFallback{rc: i.rc, g: g}, r), nil -} - -func (i *s2GeographyIndex) DWithin( - _ context.Context, - g geo.Geography, - distanceMeters float64, - useSphereOrSpheroid geogfn.UseSphereOrSpheroid, -) (UnionKeySpans, error) { - projInfo, ok := geoprojbase.Projection(g.SRID()) - if !ok { - return nil, errors.Errorf("projection not found for SRID: %d", g.SRID()) - } - if projInfo.Spheroid == nil { - return nil, errors.Errorf("projection %d does not have spheroid", g.SRID()) - } - r, err := g.AsS2(geo.EmptyBehaviorOmit) - if err != nil { - return nil, err - } - // The following approach of constructing the covering and then expanding by - // an angle is worse than first expanding the original shape and then - // constructing a covering. However the s2 golang library lacks the c++ - // S2ShapeIndexBufferedRegion, whose GetCellUnionBound() method is what we - // desire. - // - // Construct the cell covering for the shape. - gCovering := geogCovererWithBBoxFallback{rc: i.rc, g: g}.covering(r) - // Convert the distanceMeters to an angle, in order to expand the cell covering - // on the sphere by the angle. - multiplier := 1.0 - if useSphereOrSpheroid == geogfn.UseSpheroid { - // We are using a sphere to calculate an angle on a spheroid, so adjust by the - // error. - multiplier += geogfn.SpheroidErrorFraction - } - angle := s1.Angle(multiplier * distanceMeters / projInfo.Spheroid.SphereRadius) - // maxLevelDiff puts a bound on the number of cells used after the expansion. - // For example, we do not want expanding a large country by 1km to generate too - // many cells. - const maxLevelDiff = 2 - gCovering.ExpandByRadius(angle, maxLevelDiff) - // Finally, make the expanded covering obey the configuration of the index, which - // is used in the RegionCoverer. - var covering s2.CellUnion - for _, c := range gCovering { - if c.Level() > i.rc.MaxLevel { - c = c.Parent(i.rc.MaxLevel) - } - covering = append(covering, c) - } - covering.Normalize() - return intersectsUsingCovering(covering), nil -} - -func (i *s2GeographyIndex) TestingInnerCovering(g geo.Geography) s2.CellUnion { - r, _ := g.AsS2(geo.EmptyBehaviorOmit) - if r == nil { - return nil - } - return innerCovering(i.rc, r) -} diff --git a/postgres/parser/geo/geoindex/s2_geometry_index.go b/postgres/parser/geo/geoindex/s2_geometry_index.go deleted file mode 100644 index 63d8dde129..0000000000 --- a/postgres/parser/geo/geoindex/s2_geometry_index.go +++ /dev/null @@ -1,471 +0,0 @@ -// Copyright 2023 Dolthub, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Copyright 2020 The Cockroach Authors. -// -// Use of this software is governed by the Business Source License -// included in the file licenses/BSL.txt. -// -// As of the Change Date specified in that file, in accordance with -// the Business Source License, use of this software will be governed -// by the Apache License, Version 2.0, included in the file -// licenses/APL.txt. - -package geoindex - -import ( - "context" - - "github.com/cockroachdb/errors" - "github.com/golang/geo/r3" - "github.com/golang/geo/s2" - "github.com/twpayne/go-geom" - - "github.com/dolthub/doltgresql/postgres/parser/geo" - "github.com/dolthub/doltgresql/postgres/parser/geo/geomfn" - "github.com/dolthub/doltgresql/postgres/parser/geo/geopb" - "github.com/dolthub/doltgresql/postgres/parser/geo/geoprojbase" - "github.com/dolthub/doltgresql/postgres/parser/geo/geos" -) - -// s2GeometryIndex is an implementation of GeometryIndex that uses the S2 geometry -// library. -type s2GeometryIndex struct { - rc *s2.RegionCoverer - minX, maxX, minY, maxY float64 - deltaX, deltaY float64 -} - -var _ GeometryIndex = (*s2GeometryIndex)(nil) - -// We adjust the clipping bounds to be smaller by this fraction, since using -// the endpoints of face 0 in S2 causes coverings to spill out of that face. -const clippingBoundsDelta = 0.01 - -// NewS2GeometryIndex returns an index with the given configuration. All reads and -// writes on this index must use the same config. Writes must use the same -// config to correctly process deletions. Reads must use the same config since -// the bounds affect when a read needs to look at the exceedsBoundsCellID. -func NewS2GeometryIndex(cfg S2GeometryConfig) GeometryIndex { - // TODO(sumeer): Sanity check cfg. - return &s2GeometryIndex{ - rc: &s2.RegionCoverer{ - MinLevel: int(cfg.S2Config.MinLevel), - MaxLevel: int(cfg.S2Config.MaxLevel), - LevelMod: int(cfg.S2Config.LevelMod), - MaxCells: int(cfg.S2Config.MaxCells), - }, - minX: cfg.MinX, - maxX: cfg.MaxX, - minY: cfg.MinY, - maxY: cfg.MaxY, - deltaX: clippingBoundsDelta * (cfg.MaxX - cfg.MinX), - deltaY: clippingBoundsDelta * (cfg.MaxY - cfg.MinY), - } -} - -// TODO(sumeer): also support index config with parameters specified by CREATE -// INDEX. - -// DefaultGeometryIndexConfig returns a default config for a geometry index. -func DefaultGeometryIndexConfig() *Config { - return &Config{ - S2Geometry: &S2GeometryConfig{ - // Arbitrary bounding box. - MinX: -10000, - MaxX: 10000, - MinY: -10000, - MaxY: 10000, - S2Config: DefaultS2Config()}, - } -} - -// GeometryIndexConfigForSRID returns a geometry index config for srid. -func GeometryIndexConfigForSRID(srid geopb.SRID) (*Config, error) { - if srid == 0 { - return DefaultGeometryIndexConfig(), nil - } - p, exists := geoprojbase.Projection(srid) - if !exists { - return nil, errors.Newf("expected definition for SRID %d", srid) - } - b := p.Bounds - minX, maxX, minY, maxY := b.MinX, b.MaxX, b.MinY, b.MaxY - // There are projections where the min and max are equal e.g. 3571. - // We need to have a valid rectangle as the geometry index bounds. - if maxX-minX < 1 { - maxX++ - } - if maxY-minY < 1 { - maxY++ - } - // We are covering shapes using cells that are square. If we have shapes - // that start off as well-behaved wrt square cells, we do not wish to - // distort them significantly. Hence, we equalize MaxX-MinX and MaxY-MinY - // in the index bounds. - diffX := maxX - minX - diffY := maxY - minY - if diffX > diffY { - adjustment := (diffX - diffY) / 2 - minY -= adjustment - maxY += adjustment - } else { - adjustment := (diffY - diffX) / 2 - minX -= adjustment - maxX += adjustment - } - // Expand the bounds by 2x the clippingBoundsDelta, to - // ensure that shapes touching the bounds don't get - // clipped. - boundsExpansion := 2 * clippingBoundsDelta - deltaX := (maxX - minX) * boundsExpansion - deltaY := (maxY - minY) * boundsExpansion - return &Config{ - S2Geometry: &S2GeometryConfig{ - MinX: minX - deltaX, - MaxX: maxX + deltaX, - MinY: minY - deltaY, - MaxY: maxY + deltaY, - S2Config: DefaultS2Config()}, - }, nil -} - -// A cell id unused by S2. We use it to index geometries that exceed the -// configured bounds. -const exceedsBoundsCellID = s2.CellID(^uint64(0)) - -// TODO(sumeer): adjust code to handle precision issues with floating point -// arithmetic. - -// geomCovererWithBBoxFallback first computes the covering for the provided -// regions (which were computed using geom), and if the covering is too -// broad (contains faces other than 0), falls back to using the bounding -// box of geom to compute the covering. -type geomCovererWithBBoxFallback struct { - s *s2GeometryIndex - geom geom.T -} - -var _ covererInterface = geomCovererWithBBoxFallback{} - -func (rc geomCovererWithBBoxFallback) covering(regions []s2.Region) s2.CellUnion { - cu := simpleCovererImpl{rc: rc.s.rc}.covering(regions) - if isBadGeomCovering(cu) { - bbox := geo.BoundingBoxFromGeomTGeometryType(rc.geom) - flatCoords := []float64{ - bbox.LoX, bbox.LoY, bbox.HiX, bbox.LoY, bbox.HiX, bbox.HiY, bbox.LoX, bbox.HiY, - bbox.LoX, bbox.LoY} - bboxT := geom.NewPolygonFlat(geom.XY, flatCoords, []int{len(flatCoords)}) - bboxRegions := rc.s.s2RegionsFromPlanarGeomT(bboxT) - bboxCU := simpleCovererImpl{rc: rc.s.rc}.covering(bboxRegions) - if !isBadGeomCovering(bboxCU) { - cu = bboxCU - } - } - return cu -} - -func isBadGeomCovering(cu s2.CellUnion) bool { - for _, c := range cu { - if c.Face() != 0 { - // Good coverings should not see a face other than 0. - return true - } - } - return false -} - -// InvertedIndexKeys implements the GeometryIndex interface. -func (s *s2GeometryIndex) InvertedIndexKeys(c context.Context, g geo.Geometry) ([]Key, error) { - // If the geometry exceeds the bounds, we index the clipped geometry in - // addition to the special cell, so that queries for geometries that don't - // exceed the bounds don't need to query the special cell (which would - // become a hotspot in the key space). - gt, clipped, err := s.convertToGeomTAndTryClip(g) - if err != nil { - return nil, err - } - var keys []Key - if gt != nil { - r := s.s2RegionsFromPlanarGeomT(gt) - keys = invertedIndexKeys(c, geomCovererWithBBoxFallback{s: s, geom: gt}, r) - } - if clipped { - keys = append(keys, Key(exceedsBoundsCellID)) - } - return keys, nil -} - -// Covers implements the GeometryIndex interface. -func (s *s2GeometryIndex) Covers(c context.Context, g geo.Geometry) (UnionKeySpans, error) { - return s.Intersects(c, g) -} - -// CoveredBy implements the GeometryIndex interface. -func (s *s2GeometryIndex) CoveredBy(c context.Context, g geo.Geometry) (RPKeyExpr, error) { - // If the geometry exceeds the bounds, we use the clipped geometry to - // restrict the search within the bounds. - gt, clipped, err := s.convertToGeomTAndTryClip(g) - if err != nil { - return nil, err - } - var expr RPKeyExpr - if gt != nil { - r := s.s2RegionsFromPlanarGeomT(gt) - expr = coveredBy(c, s.rc, r) - } - if clipped { - // Intersect with the shapes that exceed the bounds. - expr = append(expr, Key(exceedsBoundsCellID)) - if len(expr) > 1 { - expr = append(expr, RPSetIntersection) - } - } - return expr, nil -} - -// Intersects implements the GeometryIndex interface. -func (s *s2GeometryIndex) Intersects(c context.Context, g geo.Geometry) (UnionKeySpans, error) { - // If the geometry exceeds the bounds, we use the clipped geometry to - // restrict the search within the bounds. - gt, clipped, err := s.convertToGeomTAndTryClip(g) - if err != nil { - return nil, err - } - var spans UnionKeySpans - if gt != nil { - r := s.s2RegionsFromPlanarGeomT(gt) - spans = intersects(c, geomCovererWithBBoxFallback{s: s, geom: gt}, r) - } - if clipped { - // And lookup all shapes that exceed the bounds. The exceedsBoundsCellID is the largest - // possible key, so appending it maintains the sorted order of spans. - spans = append(spans, KeySpan{Start: Key(exceedsBoundsCellID), End: Key(exceedsBoundsCellID)}) - } - return spans, nil -} - -func (s *s2GeometryIndex) DWithin( - c context.Context, g geo.Geometry, distance float64, -) (UnionKeySpans, error) { - // TODO(sumeer): are the default params the correct thing to use here? - g, err := geomfn.Buffer(g, geomfn.MakeDefaultBufferParams(), distance) - if err != nil { - return nil, err - } - return s.Intersects(c, g) -} - -func (s *s2GeometryIndex) DFullyWithin( - c context.Context, g geo.Geometry, distance float64, -) (UnionKeySpans, error) { - // TODO(sumeer): are the default params the correct thing to use here? - g, err := geomfn.Buffer(g, geomfn.MakeDefaultBufferParams(), distance) - if err != nil { - return nil, err - } - return s.Covers(c, g) -} - -// Converts to geom.T and clips to the rectangle bounds of the index. -func (s *s2GeometryIndex) convertToGeomTAndTryClip(g geo.Geometry) (geom.T, bool, error) { - gt, err := g.AsGeomT() - if err != nil { - return nil, false, err - } - if gt.Empty() { - return gt, false, nil - } - clipped := false - if s.geomExceedsBounds(gt) { - clipped = true - clippedEWKB, err := - geos.ClipByRect(g.EWKB(), s.minX+s.deltaX, s.minY+s.deltaY, s.maxX-s.deltaX, s.maxY-s.deltaY) - if err != nil { - return nil, false, err - } - gt = nil - if clippedEWKB != nil { - g, err = geo.ParseGeometryFromEWKBUnsafe(clippedEWKB) - if err != nil { - return nil, false, err - } - gt, err = g.AsGeomT() - if err != nil { - return nil, false, err - } - } - } - return gt, clipped, nil -} - -// Returns true if the point represented by (x, y) exceeds the rectangle -// bounds of the index. -func (s *s2GeometryIndex) xyExceedsBounds(x float64, y float64) bool { - if x < (s.minX+s.deltaX) || x > (s.maxX-s.deltaX) { - return true - } - if y < (s.minY+s.deltaY) || y > (s.maxY-s.deltaY) { - return true - } - return false -} - -// Returns true if g exceeds the rectangle bounds of the index. -func (s *s2GeometryIndex) geomExceedsBounds(g geom.T) bool { - switch repr := g.(type) { - case *geom.Point: - return s.xyExceedsBounds(repr.X(), repr.Y()) - case *geom.LineString: - for i := 0; i < repr.NumCoords(); i++ { - p := repr.Coord(i) - if s.xyExceedsBounds(p.X(), p.Y()) { - return true - } - } - case *geom.Polygon: - if repr.NumLinearRings() > 0 { - lr := repr.LinearRing(0) - for i := 0; i < lr.NumCoords(); i++ { - if s.xyExceedsBounds(lr.Coord(i).X(), lr.Coord(i).Y()) { - return true - } - } - } - case *geom.GeometryCollection: - for _, geom := range repr.Geoms() { - if s.geomExceedsBounds(geom) { - return true - } - } - case *geom.MultiPoint: - for i := 0; i < repr.NumPoints(); i++ { - if s.geomExceedsBounds(repr.Point(i)) { - return true - } - } - case *geom.MultiLineString: - for i := 0; i < repr.NumLineStrings(); i++ { - if s.geomExceedsBounds(repr.LineString(i)) { - return true - } - } - case *geom.MultiPolygon: - for i := 0; i < repr.NumPolygons(); i++ { - if s.geomExceedsBounds(repr.Polygon(i)) { - return true - } - } - } - return false -} - -// stToUV() and face0UVToXYZPoint() are adapted from unexported methods in -// github.com/golang/geo/s2/stuv.go - -// stToUV converts an s or t value to the corresponding u or v value. -// This is a non-linear transformation from [-1,1] to [-1,1] that -// attempts to make the cell sizes more uniform. -// This uses what the C++ version calls 'the quadratic transform'. -func stToUV(s float64) float64 { - if s >= 0.5 { - return (1 / 3.) * (4*s*s - 1) - } - return (1 / 3.) * (1 - 4*(1-s)*(1-s)) -} - -// Specialized version of faceUVToXYZ() for face 0 -func face0UVToXYZPoint(u, v float64) s2.Point { - return s2.Point{Vector: r3.Vector{X: 1, Y: u, Z: v}} -} - -func (s *s2GeometryIndex) planarPointToS2Point(x float64, y float64) s2.Point { - ss := (x - s.minX) / (s.maxX - s.minX) - tt := (y - s.minY) / (s.maxY - s.minY) - u := stToUV(ss) - v := stToUV(tt) - return face0UVToXYZPoint(u, v) -} - -// TODO(sumeer): this is similar to S2RegionsFromGeomT() but needs to do -// a different point conversion. If these functions do not diverge further, -// and turn out not to be performance critical, merge the two implementations. -func (s *s2GeometryIndex) s2RegionsFromPlanarGeomT(geomRepr geom.T) []s2.Region { - if geomRepr.Empty() { - return nil - } - var regions []s2.Region - switch repr := geomRepr.(type) { - case *geom.Point: - regions = []s2.Region{ - s.planarPointToS2Point(repr.X(), repr.Y()), - } - case *geom.LineString: - points := make([]s2.Point, repr.NumCoords()) - for i := 0; i < repr.NumCoords(); i++ { - p := repr.Coord(i) - points[i] = s.planarPointToS2Point(p.X(), p.Y()) - } - pl := s2.Polyline(points) - regions = []s2.Region{&pl} - case *geom.Polygon: - loops := make([]*s2.Loop, repr.NumLinearRings()) - // The first ring is a "shell". Following rings are "holes". - // All loops must be oriented CCW for S2. - for ringIdx := 0; ringIdx < repr.NumLinearRings(); ringIdx++ { - linearRing := repr.LinearRing(ringIdx) - points := make([]s2.Point, linearRing.NumCoords()) - isCCW := geo.IsLinearRingCCW(linearRing) - for pointIdx := 0; pointIdx < linearRing.NumCoords(); pointIdx++ { - p := linearRing.Coord(pointIdx) - pt := s.planarPointToS2Point(p.X(), p.Y()) - if isCCW { - points[pointIdx] = pt - } else { - points[len(points)-pointIdx-1] = pt - } - } - loops[ringIdx] = s2.LoopFromPoints(points) - } - regions = []s2.Region{ - s2.PolygonFromLoops(loops), - } - case *geom.GeometryCollection: - for _, geom := range repr.Geoms() { - regions = append(regions, s.s2RegionsFromPlanarGeomT(geom)...) - } - case *geom.MultiPoint: - for i := 0; i < repr.NumPoints(); i++ { - regions = append(regions, s.s2RegionsFromPlanarGeomT(repr.Point(i))...) - } - case *geom.MultiLineString: - for i := 0; i < repr.NumLineStrings(); i++ { - regions = append(regions, s.s2RegionsFromPlanarGeomT(repr.LineString(i))...) - } - case *geom.MultiPolygon: - for i := 0; i < repr.NumPolygons(); i++ { - regions = append(regions, s.s2RegionsFromPlanarGeomT(repr.Polygon(i))...) - } - } - return regions -} - -func (s *s2GeometryIndex) TestingInnerCovering(g geo.Geometry) s2.CellUnion { - gt, _, err := s.convertToGeomTAndTryClip(g) - if err != nil || gt == nil { - return nil - } - r := s.s2RegionsFromPlanarGeomT(gt) - return innerCovering(s.rc, r) -} diff --git a/postgres/parser/geo/geomfn/affine_transforms.go b/postgres/parser/geo/geomfn/affine_transforms.go deleted file mode 100644 index 8da601e118..0000000000 --- a/postgres/parser/geo/geomfn/affine_transforms.go +++ /dev/null @@ -1,243 +0,0 @@ -// Copyright 2023 Dolthub, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Copyright 2020 The Cockroach Authors. -// -// Use of this software is governed by the Business Source License -// included in the file licenses/BSL.txt. -// -// As of the Change Date specified in that file, in accordance with -// the Business Source License, use of this software will be governed -// by the Apache License, Version 2.0, included in the file -// licenses/APL.txt. - -package geomfn - -import ( - "math" - - "github.com/cockroachdb/errors" - "github.com/twpayne/go-geom" - - "github.com/dolthub/doltgresql/postgres/parser/geo" -) - -// AffineMatrix defines an affine transformation matrix for a geom object. -// It is expected to be of the form: -// a b c x_off -// d e f y_off -// g h i z_off -// 0 0 0 1 -// Which gets applies onto a coordinate of form: -// (x y z 0)^T -// With the following transformation: -// x' = a*x + b*y + c*z + x_off -// y' = d*x + e*y + f*z + y_off -// z' = g*x + h*y + i*z + z_off -type AffineMatrix [][]float64 - -// Affine applies a 3D affine transformation onto the given geometry. -// See: https://en.wikipedia.org/wiki/Affine_transformation. -func Affine(g geo.Geometry, m AffineMatrix) (geo.Geometry, error) { - if g.Empty() { - return g, nil - } - - t, err := g.AsGeomT() - if err != nil { - return geo.Geometry{}, err - } - - newT, err := affine(t, m) - if err != nil { - return geo.Geometry{}, err - } - return geo.MakeGeometryFromGeomT(newT) -} - -func affine(t geom.T, m AffineMatrix) (geom.T, error) { - return applyOnCoordsForGeomT(t, func(l geom.Layout, dst, src []float64) error { - var z float64 - if l.ZIndex() != -1 { - z = src[l.ZIndex()] - } - newX := m[0][0]*src[0] + m[0][1]*src[1] + m[0][2]*z + m[0][3] - newY := m[1][0]*src[0] + m[1][1]*src[1] + m[1][2]*z + m[1][3] - newZ := m[2][0]*src[0] + m[2][1]*src[1] + m[2][2]*z + m[2][3] - - dst[0] = newX - dst[1] = newY - if l.ZIndex() != -1 { - dst[2] = newZ - } - if l.MIndex() != -1 { - dst[l.MIndex()] = src[l.MIndex()] - } - return nil - }) -} - -// Translate returns a modified Geometry whose coordinates are incremented -// or decremented by the deltas. -func Translate(g geo.Geometry, deltas []float64) (geo.Geometry, error) { - if g.Empty() { - return g, nil - } - - t, err := g.AsGeomT() - if err != nil { - return geo.Geometry{}, err - } - - newT, err := translate(t, deltas) - if err != nil { - return geo.Geometry{}, err - } - return geo.MakeGeometryFromGeomT(newT) -} - -func translate(t geom.T, deltas []float64) (geom.T, error) { - if t.Layout().Stride() != len(deltas) { - err := geom.ErrStrideMismatch{ - Got: len(deltas), - Want: t.Layout().Stride(), - } - return nil, errors.Wrap(err, "translating coordinates") - } - var zOff float64 - if t.Layout().ZIndex() != -1 { - zOff = deltas[t.Layout().ZIndex()] - } - return affine( - t, - AffineMatrix([][]float64{ - {1, 0, 0, deltas[0]}, - {0, 1, 0, deltas[1]}, - {0, 0, 1, zOff}, - {0, 0, 0, 1}, - }), - ) -} - -// Scale returns a modified Geometry whose coordinates are multiplied by the factors. -func Scale(g geo.Geometry, factors []float64) (geo.Geometry, error) { - var zFactor float64 - if len(factors) > 2 { - zFactor = factors[2] - } - return Affine( - g, - AffineMatrix([][]float64{ - {factors[0], 0, 0, 0}, - {0, factors[1], 0, 0}, - {0, 0, zFactor, 0}, - {0, 0, 0, 1}, - }), - ) -} - -// ScaleRelativeToOrigin returns a modified Geometry whose coordinates are multiplied by the factors relative to the origin -func ScaleRelativeToOrigin( - g geo.Geometry, factor geo.Geometry, origin geo.Geometry, -) (geo.Geometry, error) { - if g.Empty() { - return g, nil - } - - t, err := g.AsGeomT() - if err != nil { - return geo.Geometry{}, err - } - - factorG, err := factor.AsGeomT() - if err != nil { - return geo.Geometry{}, err - } - - factorPointG, ok := factorG.(*geom.Point) - if !ok { - return geo.Geometry{}, errors.Newf("the scaling factor must be a Point") - } - - originG, err := origin.AsGeomT() - if err != nil { - return geo.Geometry{}, err - } - - originPointG, ok := originG.(*geom.Point) - if !ok { - return geo.Geometry{}, errors.Newf("the false origin must be a Point") - } - - if factorG.Stride() != originG.Stride() { - err := geom.ErrStrideMismatch{ - Got: factorG.Stride(), - Want: originG.Stride(), - } - return geo.Geometry{}, errors.Wrap(err, "number of dimensions for the scaling factor and origin must be equal") - } - - // Offset by the origin, scale, and translate it back to the origin. - offsetDeltas := make([]float64, 0, 3) - offsetDeltas = append(offsetDeltas, -originPointG.X(), -originPointG.Y()) - if originG.Layout().ZIndex() != -1 { - offsetDeltas = append(offsetDeltas, -originPointG.Z()) - } - retT, err := translate(t, offsetDeltas) - if err != nil { - return geo.Geometry{}, err - } - - xFactor, yFactor := factorPointG.X(), factorPointG.Y() - var zFactor float64 = 1 - if factorPointG.Layout().ZIndex() != -1 { - zFactor = factorPointG.Z() - } - retT, err = affine( - retT, - AffineMatrix([][]float64{ - {xFactor, 0, 0, 0}, - {0, yFactor, 0, 0}, - {0, 0, zFactor, 0}, - {0, 0, 0, 1}, - }), - ) - if err != nil { - return geo.Geometry{}, err - } - - for i := range offsetDeltas { - offsetDeltas[i] = -offsetDeltas[i] - } - retT, err = translate(retT, offsetDeltas) - if err != nil { - return geo.Geometry{}, err - } - - return geo.MakeGeometryFromGeomT(retT) -} - -// Rotate returns a modified Geometry whose coordinates are rotated -// around the origin by a rotation angle. -func Rotate(g geo.Geometry, rotRadians float64) (geo.Geometry, error) { - return Affine( - g, - AffineMatrix([][]float64{ - {math.Cos(rotRadians), -math.Sin(rotRadians), 0, 0}, - {math.Sin(rotRadians), math.Cos(rotRadians), 0, 0}, - {0, 0, 1, 0}, - {0, 0, 0, 1}, - }), - ) -} diff --git a/postgres/parser/geo/geomfn/azimuth.go b/postgres/parser/geo/geomfn/azimuth.go deleted file mode 100644 index f57762d36e..0000000000 --- a/postgres/parser/geo/geomfn/azimuth.go +++ /dev/null @@ -1,82 +0,0 @@ -// Copyright 2023 Dolthub, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Copyright 2020 The Cockroach Authors. -// -// Use of this software is governed by the Business Source License -// included in the file licenses/BSL.txt. -// -// As of the Change Date specified in that file, in accordance with -// the Business Source License, use of this software will be governed -// by the Apache License, Version 2.0, included in the file -// licenses/APL.txt. - -package geomfn - -import ( - "math" - - "github.com/cockroachdb/errors" - "github.com/twpayne/go-geom" - - "github.com/dolthub/doltgresql/postgres/parser/geo" -) - -// Azimuth returns the azimuth in radians of the segment defined by the given point geometries, -// where point a is the reference point. -// The reference direction from which the azimuth is calculated is north, and is positive clockwise. -// i.e. North = 0; East = Ï€/2; South = Ï€; West = 3Ï€/2. -// See https://en.wikipedia.org/wiki/Polar_coordinate_system. -// Returns nil if the two points are the same. -// Returns an error if any of the two Geometry items are not points. -func Azimuth(a geo.Geometry, b geo.Geometry) (*float64, error) { - if a.SRID() != b.SRID() { - return nil, geo.NewMismatchingSRIDsError(a.SpatialObject(), b.SpatialObject()) - } - - aGeomT, err := a.AsGeomT() - if err != nil { - return nil, err - } - - aPoint, ok := aGeomT.(*geom.Point) - if !ok { - return nil, errors.Newf("arguments must be POINT geometries") - } - - bGeomT, err := b.AsGeomT() - if err != nil { - return nil, err - } - - bPoint, ok := bGeomT.(*geom.Point) - if !ok { - return nil, errors.Newf("arguments must be POINT geometries") - } - - if aPoint.Empty() || bPoint.Empty() { - return nil, errors.Newf("cannot call ST_Azimuth with POINT EMPTY") - } - - if aPoint.X() == bPoint.X() && aPoint.Y() == bPoint.Y() { - return nil, nil - } - - atan := math.Atan2(bPoint.Y()-aPoint.Y(), bPoint.X()-aPoint.X()) - // math.Pi / 2 is North from the atan calculation this is a CCW direction. - // We want to return a CW direction, so subtract atan from math.Pi / 2 to get it into a CW direction. - // Then add 2*math.Pi to ensure a positive azimuth. - azimuth := math.Mod(math.Pi/2-atan+2*math.Pi, 2*math.Pi) - return &azimuth, nil -} diff --git a/postgres/parser/geo/geomfn/binary_predicates.go b/postgres/parser/geo/geomfn/binary_predicates.go deleted file mode 100644 index 11bd871f32..0000000000 --- a/postgres/parser/geo/geomfn/binary_predicates.go +++ /dev/null @@ -1,154 +0,0 @@ -// Copyright 2023 Dolthub, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Copyright 2020 The Cockroach Authors. -// -// Use of this software is governed by the Business Source License -// included in the file licenses/BSL.txt. -// -// As of the Change Date specified in that file, in accordance with -// the Business Source License, use of this software will be governed -// by the Apache License, Version 2.0, included in the file -// licenses/APL.txt. - -package geomfn - -import ( - "github.com/dolthub/doltgresql/postgres/parser/geo" - "github.com/dolthub/doltgresql/postgres/parser/geo/geos" -) - -// Covers returns whether geometry A covers geometry B. -func Covers(a geo.Geometry, b geo.Geometry) (bool, error) { - if a.SRID() != b.SRID() { - return false, geo.NewMismatchingSRIDsError(a.SpatialObject(), b.SpatialObject()) - } - if !a.CartesianBoundingBox().Covers(b.CartesianBoundingBox()) { - return false, nil - } - return geos.Covers(a.EWKB(), b.EWKB()) -} - -// CoveredBy returns whether geometry A is covered by geometry B. -func CoveredBy(a geo.Geometry, b geo.Geometry) (bool, error) { - if a.SRID() != b.SRID() { - return false, geo.NewMismatchingSRIDsError(a.SpatialObject(), b.SpatialObject()) - } - if !b.CartesianBoundingBox().Covers(a.CartesianBoundingBox()) { - return false, nil - } - return geos.CoveredBy(a.EWKB(), b.EWKB()) -} - -// Contains returns whether geometry A contains geometry B. -func Contains(a geo.Geometry, b geo.Geometry) (bool, error) { - if a.SRID() != b.SRID() { - return false, geo.NewMismatchingSRIDsError(a.SpatialObject(), b.SpatialObject()) - } - if !a.CartesianBoundingBox().Covers(b.CartesianBoundingBox()) { - return false, nil - } - return geos.Contains(a.EWKB(), b.EWKB()) -} - -// ContainsProperly returns whether geometry A properly contains geometry B. -func ContainsProperly(a geo.Geometry, b geo.Geometry) (bool, error) { - if a.SRID() != b.SRID() { - return false, geo.NewMismatchingSRIDsError(a.SpatialObject(), b.SpatialObject()) - } - if !a.CartesianBoundingBox().Covers(b.CartesianBoundingBox()) { - return false, nil - } - return geos.RelatePattern(a.EWKB(), b.EWKB(), "T**FF*FF*") -} - -// Crosses returns whether geometry A crosses geometry B. -func Crosses(a geo.Geometry, b geo.Geometry) (bool, error) { - if a.SRID() != b.SRID() { - return false, geo.NewMismatchingSRIDsError(a.SpatialObject(), b.SpatialObject()) - } - if !a.CartesianBoundingBox().Intersects(b.CartesianBoundingBox()) { - return false, nil - } - return geos.Crosses(a.EWKB(), b.EWKB()) -} - -// Disjoint returns whether geometry A is disjoint from geometry B. -func Disjoint(a geo.Geometry, b geo.Geometry) (bool, error) { - if a.SRID() != b.SRID() { - return false, geo.NewMismatchingSRIDsError(a.SpatialObject(), b.SpatialObject()) - } - return geos.Disjoint(a.EWKB(), b.EWKB()) -} - -// Equals returns whether geometry A equals geometry B. -func Equals(a geo.Geometry, b geo.Geometry) (bool, error) { - if a.SRID() != b.SRID() { - return false, geo.NewMismatchingSRIDsError(a.SpatialObject(), b.SpatialObject()) - } - // Empty items are equal to each other. - // Do this check before the BoundingBoxIntersects check, as we would otherwise - // return false. - if a.Empty() && b.Empty() { - return true, nil - } - if !a.CartesianBoundingBox().Covers(b.CartesianBoundingBox()) { - return false, nil - } - return geos.Equals(a.EWKB(), b.EWKB()) -} - -// Intersects returns whether geometry A intersects geometry B. -func Intersects(a geo.Geometry, b geo.Geometry) (bool, error) { - if a.SRID() != b.SRID() { - return false, geo.NewMismatchingSRIDsError(a.SpatialObject(), b.SpatialObject()) - } - if !a.CartesianBoundingBox().Intersects(b.CartesianBoundingBox()) { - return false, nil - } - return geos.Intersects(a.EWKB(), b.EWKB()) -} - -// Overlaps returns whether geometry A overlaps geometry B. -func Overlaps(a geo.Geometry, b geo.Geometry) (bool, error) { - if a.SRID() != b.SRID() { - return false, geo.NewMismatchingSRIDsError(a.SpatialObject(), b.SpatialObject()) - } - if !a.CartesianBoundingBox().Intersects(b.CartesianBoundingBox()) { - return false, nil - } - return geos.Overlaps(a.EWKB(), b.EWKB()) -} - -// Touches returns whether geometry A touches geometry B. -func Touches(a geo.Geometry, b geo.Geometry) (bool, error) { - if a.SRID() != b.SRID() { - return false, geo.NewMismatchingSRIDsError(a.SpatialObject(), b.SpatialObject()) - } - if !a.CartesianBoundingBox().Intersects(b.CartesianBoundingBox()) { - return false, nil - } - return geos.Touches(a.EWKB(), b.EWKB()) -} - -// Within returns whether geometry A is within geometry B. -func Within(a geo.Geometry, b geo.Geometry) (bool, error) { - if a.SRID() != b.SRID() { - return false, geo.NewMismatchingSRIDsError(a.SpatialObject(), b.SpatialObject()) - } - if !b.CartesianBoundingBox().Covers(a.CartesianBoundingBox()) { - return false, nil - } - return geos.Within(a.EWKB(), b.EWKB()) -} diff --git a/postgres/parser/geo/geomfn/buffer.go b/postgres/parser/geo/geomfn/buffer.go deleted file mode 100644 index 4d1b556148..0000000000 --- a/postgres/parser/geo/geomfn/buffer.go +++ /dev/null @@ -1,136 +0,0 @@ -// Copyright 2023 Dolthub, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Copyright 2020 The Cockroach Authors. -// -// Use of this software is governed by the Business Source License -// included in the file licenses/BSL.txt. -// -// As of the Change Date specified in that file, in accordance with -// the Business Source License, use of this software will be governed -// by the Apache License, Version 2.0, included in the file -// licenses/APL.txt. - -package geomfn - -import ( - "strconv" - "strings" - - "github.com/cockroachdb/errors" - - "github.com/dolthub/doltgresql/postgres/parser/geo" - "github.com/dolthub/doltgresql/postgres/parser/geo/geos" -) - -// BufferParams is a wrapper around the geos.BufferParams. -type BufferParams struct { - p geos.BufferParams -} - -// MakeDefaultBufferParams returns the default BufferParams/ -func MakeDefaultBufferParams() BufferParams { - return BufferParams{ - p: geos.BufferParams{ - EndCapStyle: geos.BufferParamsEndCapStyleRound, - JoinStyle: geos.BufferParamsJoinStyleRound, - SingleSided: false, - QuadrantSegments: 8, - MitreLimit: 5.0, - }, - } -} - -// WithQuadrantSegments returns a copy of the BufferParams with the quadrantSegments set. -func (b BufferParams) WithQuadrantSegments(quadrantSegments int) BufferParams { - ret := b - ret.p.QuadrantSegments = quadrantSegments - return ret -} - -// ParseBufferParams parses the given buffer params from a SQL string into -// the BufferParams form. -// The string must be of the same format as specified by https://postgis.net/docs/ST_Buffer.html. -// Returns the BufferParams, as well as the modified distance. -func ParseBufferParams(s string, distance float64) (BufferParams, float64, error) { - p := MakeDefaultBufferParams() - fields := strings.Fields(s) - for _, field := range fields { - fParams := strings.Split(field, "=") - if len(fParams) != 2 { - return BufferParams{}, 0, errors.Newf("unknown buffer parameter: %s", fParams) - } - f, val := fParams[0], fParams[1] - switch strings.ToLower(f) { - case "quad_segs": - valInt, err := strconv.ParseInt(val, 10, 64) - if err != nil { - return BufferParams{}, 0, errors.Wrapf(err, "invalid int for %s: %s", f, val) - } - p.p.QuadrantSegments = int(valInt) - case "endcap": - switch strings.ToLower(val) { - case "round": - p.p.EndCapStyle = geos.BufferParamsEndCapStyleRound - case "flat", "butt": - p.p.EndCapStyle = geos.BufferParamsEndCapStyleFlat - case "square": - p.p.EndCapStyle = geos.BufferParamsEndCapStyleSquare - default: - return BufferParams{}, 0, errors.Newf("unknown endcap: %s (accepted: round, flat, square)", val) - } - case "join": - switch strings.ToLower(val) { - case "round": - p.p.JoinStyle = geos.BufferParamsJoinStyleRound - case "mitre", "miter": - p.p.JoinStyle = geos.BufferParamsJoinStyleMitre - case "bevel": - p.p.JoinStyle = geos.BufferParamsJoinStyleBevel - default: - return BufferParams{}, 0, errors.Newf("unknown join: %s (accepted: round, mitre, bevel)", val) - } - case "mitre_limit", "miter_limit": - valFloat, err := strconv.ParseFloat(val, 64) - if err != nil { - return BufferParams{}, 0, errors.Wrapf(err, "invalid float for %s: %s", f, val) - } - p.p.MitreLimit = valFloat - case "side": - switch strings.ToLower(val) { - case "both": - p.p.SingleSided = false - case "left": - p.p.SingleSided = true - case "right": - p.p.SingleSided = true - distance *= -1 - default: - return BufferParams{}, 0, errors.Newf("unknown side: %s (accepted: both, left, right)", val) - } - default: - return BufferParams{}, 0, errors.Newf("unknown field: %s (accepted fields: quad_segs, endcap, join, mitre_limit, side)", f) - } - } - return p, distance, nil -} - -// Buffer buffers a given Geometry by the supplied parameters. -func Buffer(g geo.Geometry, params BufferParams, distance float64) (geo.Geometry, error) { - bufferedGeom, err := geos.Buffer(g.EWKB(), params.p, distance) - if err != nil { - return geo.Geometry{}, err - } - return geo.ParseGeometryFromEWKB(bufferedGeom) -} diff --git a/postgres/parser/geo/geomfn/collections.go b/postgres/parser/geo/geomfn/collections.go deleted file mode 100644 index 94b7367886..0000000000 --- a/postgres/parser/geo/geomfn/collections.go +++ /dev/null @@ -1,440 +0,0 @@ -// Copyright 2023 Dolthub, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Copyright 2020 The Cockroach Authors. -// -// Use of this software is governed by the Business Source License -// included in the file licenses/BSL.txt. -// -// As of the Change Date specified in that file, in accordance with -// the Business Source License, use of this software will be governed -// by the Apache License, Version 2.0, included in the file -// licenses/APL.txt. - -package geomfn - -import ( - "github.com/cockroachdb/errors" - "github.com/twpayne/go-geom" - - "github.com/dolthub/doltgresql/postgres/parser/geo" - "github.com/dolthub/doltgresql/postgres/parser/geo/geopb" -) - -// Collect collects two geometries into a GeometryCollection or multi-type. -// -// This is the binary version of `ST_Collect()`, but since it's not possible to -// have an aggregate and non-aggregate function with the same name (different -// args), this is not used. Code is left behind for when we add support for -// this. Be sure to handle NULL args when adding the builtin for this, where it -// should return the non-NULL arg unused like PostGIS. -func Collect(g1 geo.Geometry, g2 geo.Geometry) (geo.Geometry, error) { - t1, err := g1.AsGeomT() - if err != nil { - return geo.Geometry{}, err - } - t2, err := g2.AsGeomT() - if err != nil { - return geo.Geometry{}, err - } - - // First, try to generate multi-types - switch t1 := t1.(type) { - case *geom.Point: - if t2, ok := t2.(*geom.Point); ok { - multi := geom.NewMultiPoint(t1.Layout()).SetSRID(t1.SRID()) - if err := multi.Push(t1); err != nil { - return geo.Geometry{}, err - } - if err := multi.Push(t2); err != nil { - return geo.Geometry{}, err - } - return geo.MakeGeometryFromGeomT(multi) - } - case *geom.LineString: - if t2, ok := t2.(*geom.LineString); ok { - multi := geom.NewMultiLineString(t1.Layout()).SetSRID(t1.SRID()) - if err := multi.Push(t1); err != nil { - return geo.Geometry{}, err - } - if err := multi.Push(t2); err != nil { - return geo.Geometry{}, err - } - return geo.MakeGeometryFromGeomT(multi) - } - case *geom.Polygon: - if t2, ok := t2.(*geom.Polygon); ok { - multi := geom.NewMultiPolygon(t1.Layout()).SetSRID(t1.SRID()) - if err := multi.Push(t1); err != nil { - return geo.Geometry{}, err - } - if err := multi.Push(t2); err != nil { - return geo.Geometry{}, err - } - return geo.MakeGeometryFromGeomT(multi) - } - } - - // Otherwise, just put them in a collection - gc := geom.NewGeometryCollection().SetSRID(t1.SRID()) - if err := gc.Push(t1); err != nil { - return geo.Geometry{}, err - } - if err := gc.Push(t2); err != nil { - return geo.Geometry{}, err - } - return geo.MakeGeometryFromGeomT(gc) -} - -// CollectionExtract returns a (multi-)geometry consisting only of the specified type. -// The type can only be point, line, or polygon. -func CollectionExtract(g geo.Geometry, shapeType geopb.ShapeType) (geo.Geometry, error) { - switch shapeType { - case geopb.ShapeType_Point, geopb.ShapeType_LineString, geopb.ShapeType_Polygon: - default: - return geo.Geometry{}, errors.Newf("only point, linestring and polygon may be extracted (got %s)", - shapeType) - } - - // If the input is already of the correct (multi-)type, just return it before - // decoding the geom.T below. - if g.ShapeType() == shapeType || g.ShapeType() == shapeType.MultiType() { - return g, nil - } - - t, err := g.AsGeomT() - if err != nil { - return geo.Geometry{}, err - } - - switch t := t.(type) { - // If the input is not a collection then return an empty geometry of the expected type. - case *geom.Point, *geom.LineString, *geom.Polygon: - switch shapeType { - case geopb.ShapeType_Point: - return geo.MakeGeometryFromGeomT(geom.NewPointEmpty(t.Layout()).SetSRID(t.SRID())) - case geopb.ShapeType_LineString: - return geo.MakeGeometryFromGeomT(geom.NewLineString(t.Layout()).SetSRID(t.SRID())) - case geopb.ShapeType_Polygon: - return geo.MakeGeometryFromGeomT(geom.NewPolygon(t.Layout()).SetSRID(t.SRID())) - default: - return geo.Geometry{}, errors.AssertionFailedf("unexpected shape type %v", shapeType.String()) - } - - // If the input is a multitype then return an empty multi-geometry of the expected type. - case *geom.MultiPoint, *geom.MultiLineString, *geom.MultiPolygon: - switch shapeType.MultiType() { - case geopb.ShapeType_MultiPoint: - return geo.MakeGeometryFromGeomT(geom.NewMultiPoint(t.Layout()).SetSRID(t.SRID())) - case geopb.ShapeType_MultiLineString: - return geo.MakeGeometryFromGeomT(geom.NewMultiLineString(t.Layout()).SetSRID(t.SRID())) - case geopb.ShapeType_MultiPolygon: - return geo.MakeGeometryFromGeomT(geom.NewMultiPolygon(t.Layout()).SetSRID(t.SRID())) - default: - return geo.Geometry{}, errors.AssertionFailedf("unexpected shape type %v", shapeType.MultiType().String()) - } - - // If the input is a collection, recursively gather geometries of the right type. - case *geom.GeometryCollection: - // Empty geos.GeometryCollection has NoLayout, while PostGIS uses XY. Returned - // multi-geometries cannot have NoLayout. - layout := t.Layout() - if layout == geom.NoLayout && t.Empty() { - layout = geom.XY - } - iter := geo.NewGeomTIterator(t, geo.EmptyBehaviorOmit) - srid := t.SRID() - - var ( - multi geom.T - err error - ) - switch shapeType { - case geopb.ShapeType_Point: - multi, err = collectionExtractPoints(iter, layout, srid) - case geopb.ShapeType_LineString: - multi, err = collectionExtractLineStrings(iter, layout, srid) - case geopb.ShapeType_Polygon: - multi, err = collectionExtractPolygons(iter, layout, srid) - default: - return geo.Geometry{}, errors.AssertionFailedf("unexpected shape type %v", shapeType.String()) - } - if err != nil { - return geo.Geometry{}, err - } - return geo.MakeGeometryFromGeomT(multi) - - default: - return geo.Geometry{}, errors.AssertionFailedf("unexpected shape type: %T", t) - } -} - -// collectionExtractPoints extracts points from an iterator. -func collectionExtractPoints( - iter geo.GeomTIterator, layout geom.Layout, srid int, -) (*geom.MultiPoint, error) { - points := geom.NewMultiPoint(layout).SetSRID(srid) - for { - if next, hasNext, err := iter.Next(); err != nil { - return nil, err - } else if !hasNext { - break - } else if point, ok := next.(*geom.Point); ok { - if err = points.Push(point); err != nil { - return nil, err - } - } - } - return points, nil -} - -// collectionExtractLineStrings extracts line strings from an iterator. -func collectionExtractLineStrings( - iter geo.GeomTIterator, layout geom.Layout, srid int, -) (*geom.MultiLineString, error) { - lineStrings := geom.NewMultiLineString(layout).SetSRID(srid) - for { - if next, hasNext, err := iter.Next(); err != nil { - return nil, err - } else if !hasNext { - break - } else if lineString, ok := next.(*geom.LineString); ok { - if err = lineStrings.Push(lineString); err != nil { - return nil, err - } - } - } - return lineStrings, nil -} - -// collectionExtractPolygons extracts polygons from an iterator. -func collectionExtractPolygons( - iter geo.GeomTIterator, layout geom.Layout, srid int, -) (*geom.MultiPolygon, error) { - polygons := geom.NewMultiPolygon(layout).SetSRID(srid) - for { - if next, hasNext, err := iter.Next(); err != nil { - return nil, err - } else if !hasNext { - break - } else if polygon, ok := next.(*geom.Polygon); ok { - if err = polygons.Push(polygon); err != nil { - return nil, err - } - } - } - return polygons, nil -} - -// CollectionHomogenize returns the simplest representation of a collection. -func CollectionHomogenize(g geo.Geometry) (geo.Geometry, error) { - t, err := g.AsGeomT() - if err != nil { - return geo.Geometry{}, err - } - srid := t.SRID() - t, err = collectionHomogenizeGeomT(t) - if err != nil { - return geo.Geometry{}, err - } - if srid != 0 { - geo.AdjustGeomTSRID(t, geopb.SRID(srid)) - } - return geo.MakeGeometryFromGeomT(t) -} - -// collectionHomogenizeGeomT homogenizes a geom.T collection. -func collectionHomogenizeGeomT(t geom.T) (geom.T, error) { - switch t := t.(type) { - case *geom.Point, *geom.LineString, *geom.Polygon: - return t, nil - - case *geom.MultiPoint: - if t.NumPoints() == 1 { - return t.Point(0), nil - } - return t, nil - - case *geom.MultiLineString: - if t.NumLineStrings() == 1 { - return t.LineString(0), nil - } - return t, nil - - case *geom.MultiPolygon: - if t.NumPolygons() == 1 { - return t.Polygon(0), nil - } - return t, nil - - case *geom.GeometryCollection: - layout := t.Layout() - if layout == geom.NoLayout && t.Empty() { - layout = geom.XY - } - points := geom.NewMultiPoint(layout) - linestrings := geom.NewMultiLineString(layout) - polygons := geom.NewMultiPolygon(layout) - iter := geo.NewGeomTIterator(t, geo.EmptyBehaviorOmit) - for { - next, hasNext, err := iter.Next() - if err != nil { - return nil, err - } - if !hasNext { - break - } - switch next := next.(type) { - case *geom.Point: - err = points.Push(next) - case *geom.LineString: - err = linestrings.Push(next) - case *geom.Polygon: - err = polygons.Push(next) - default: - err = errors.AssertionFailedf("encountered unexpected geometry type: %T", next) - } - if err != nil { - return nil, err - } - } - homog := geom.NewGeometryCollection() - switch points.NumPoints() { - case 0: - case 1: - if err := homog.Push(points.Point(0)); err != nil { - return nil, err - } - default: - if err := homog.Push(points); err != nil { - return nil, err - } - } - switch linestrings.NumLineStrings() { - case 0: - case 1: - if err := homog.Push(linestrings.LineString(0)); err != nil { - return nil, err - } - default: - if err := homog.Push(linestrings); err != nil { - return nil, err - } - } - switch polygons.NumPolygons() { - case 0: - case 1: - if err := homog.Push(polygons.Polygon(0)); err != nil { - return nil, err - } - default: - if err := homog.Push(polygons); err != nil { - return nil, err - } - } - - if homog.NumGeoms() == 1 { - return homog.Geom(0), nil - } - return homog, nil - - default: - return nil, errors.AssertionFailedf("unknown geometry type: %T", t) - } -} - -// ForceCollection converts the input into a GeometryCollection. -func ForceCollection(g geo.Geometry) (geo.Geometry, error) { - t, err := g.AsGeomT() - if err != nil { - return geo.Geometry{}, err - } - t, err = forceCollectionFromGeomT(t) - if err != nil { - return geo.Geometry{}, err - } - return geo.MakeGeometryFromGeomT(t) -} - -// forceCollectionFromGeomT converts a geom.T into a geom.GeometryCollection. -func forceCollectionFromGeomT(t geom.T) (geom.T, error) { - gc := geom.NewGeometryCollection().SetSRID(t.SRID()) - switch t := t.(type) { - case *geom.Point, *geom.LineString, *geom.Polygon: - if err := gc.Push(t); err != nil { - return nil, err - } - case *geom.MultiPoint: - for i := 0; i < t.NumPoints(); i++ { - if err := gc.Push(t.Point(i)); err != nil { - return nil, err - } - } - case *geom.MultiLineString: - for i := 0; i < t.NumLineStrings(); i++ { - if err := gc.Push(t.LineString(i)); err != nil { - return nil, err - } - } - case *geom.MultiPolygon: - for i := 0; i < t.NumPolygons(); i++ { - if err := gc.Push(t.Polygon(i)); err != nil { - return nil, err - } - } - case *geom.GeometryCollection: - gc = t - default: - return nil, errors.AssertionFailedf("unknown geometry type: %T", t) - } - return gc, nil -} - -// Multi converts the given geometry into a new multi-geometry. -func Multi(g geo.Geometry) (geo.Geometry, error) { - t, err := g.AsGeomT() // implicitly clones the input - if err != nil { - return geo.Geometry{}, err - } - switch t := t.(type) { - case *geom.MultiPoint, *geom.MultiLineString, *geom.MultiPolygon, *geom.GeometryCollection: - return geo.MakeGeometryFromGeomT(t) - case *geom.Point: - multi := geom.NewMultiPoint(t.Layout()).SetSRID(t.SRID()) - if !t.Empty() { - if err = multi.Push(t); err != nil { - return geo.Geometry{}, err - } - } - return geo.MakeGeometryFromGeomT(multi) - case *geom.LineString: - multi := geom.NewMultiLineString(t.Layout()).SetSRID(t.SRID()) - if !t.Empty() { - if err = multi.Push(t); err != nil { - return geo.Geometry{}, err - } - } - return geo.MakeGeometryFromGeomT(multi) - case *geom.Polygon: - multi := geom.NewMultiPolygon(t.Layout()).SetSRID(t.SRID()) - if !t.Empty() { - if err = multi.Push(t); err != nil { - return geo.Geometry{}, err - } - } - return geo.MakeGeometryFromGeomT(multi) - default: - return geo.Geometry{}, errors.AssertionFailedf("unknown geometry type: %T", t) - } -} diff --git a/postgres/parser/geo/geomfn/coord.go b/postgres/parser/geo/geomfn/coord.go deleted file mode 100644 index ef11d2b578..0000000000 --- a/postgres/parser/geo/geomfn/coord.go +++ /dev/null @@ -1,71 +0,0 @@ -// Copyright 2023 Dolthub, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Copyright 2020 The Cockroach Authors. -// -// Use of this software is governed by the Business Source License -// included in the file licenses/BSL.txt. -// -// As of the Change Date specified in that file, in accordance with -// the Business Source License, use of this software will be governed -// by the Apache License, Version 2.0, included in the file -// licenses/APL.txt. - -package geomfn - -import ( - "math" - - "github.com/twpayne/go-geom" -) - -// coordAdd adds two coordinates and returns a new result. -func coordAdd(a geom.Coord, b geom.Coord) geom.Coord { - return geom.Coord{a.X() + b.X(), a.Y() + b.Y()} -} - -// coordSub subtracts two coordinates and returns a new result. -func coordSub(a geom.Coord, b geom.Coord) geom.Coord { - return geom.Coord{a.X() - b.X(), a.Y() - b.Y()} -} - -// coordMul multiplies a coord by a scalar and returns the new result. -func coordMul(a geom.Coord, s float64) geom.Coord { - return geom.Coord{a.X() * s, a.Y() * s} -} - -// coordDot returns the dot product of two coords if the coord was a vector. -func coordDot(a geom.Coord, b geom.Coord) float64 { - return a.X()*b.X() + a.Y()*b.Y() -} - -// coordNorm2 returns the normalization^2 of a coordinate if the coord was a vector. -func coordNorm2(c geom.Coord) float64 { - return coordDot(c, c) -} - -// coordNorm returns the normalization of a coordinate if the coord was a vector. -func coordNorm(c geom.Coord) float64 { - return math.Sqrt(coordNorm2(c)) -} - -// coordEqual returns whether two coordinates are equal. -func coordEqual(a geom.Coord, b geom.Coord) bool { - return a.X() == b.X() && a.Y() == b.Y() -} - -// coordMag2 returns the magnitude^2 of a coordinate if the coord was a vector. -func coordMag2(c geom.Coord) float64 { - return coordDot(c, c) -} diff --git a/postgres/parser/geo/geomfn/de9im.go b/postgres/parser/geo/geomfn/de9im.go deleted file mode 100644 index b6ea016f84..0000000000 --- a/postgres/parser/geo/geomfn/de9im.go +++ /dev/null @@ -1,99 +0,0 @@ -// Copyright 2023 Dolthub, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Copyright 2020 The Cockroach Authors. -// -// Use of this software is governed by the Business Source License -// included in the file licenses/BSL.txt. -// -// As of the Change Date specified in that file, in accordance with -// the Business Source License, use of this software will be governed -// by the Apache License, Version 2.0, included in the file -// licenses/APL.txt. - -package geomfn - -import ( - "github.com/cockroachdb/errors" - - "github.com/dolthub/doltgresql/postgres/parser/geo" - "github.com/dolthub/doltgresql/postgres/parser/geo/geos" -) - -// Relate returns the DE-9IM relation between A and B. -func Relate(a geo.Geometry, b geo.Geometry) (string, error) { - if a.SRID() != b.SRID() { - return "", geo.NewMismatchingSRIDsError(a.SpatialObject(), b.SpatialObject()) - } - return geos.Relate(a.EWKB(), b.EWKB()) -} - -// RelatePattern returns whether the DE-9IM relation between A and B matches. -func RelatePattern(a geo.Geometry, b geo.Geometry, pattern string) (bool, error) { - if a.SRID() != b.SRID() { - return false, geo.NewMismatchingSRIDsError(a.SpatialObject(), b.SpatialObject()) - } - return geos.RelatePattern(a.EWKB(), b.EWKB(), pattern) -} - -// MatchesDE9IM checks whether the given DE-9IM relation matches the DE-91M pattern. -// Assumes the relation has been computed, and such has no 'T' and '*' characters. -// See: https://en.wikipedia.org/wiki/DE-9IM. -func MatchesDE9IM(relation string, pattern string) (bool, error) { - if len(relation) != 9 { - return false, errors.Newf("relation %q should be of length 9", relation) - } - if len(pattern) != 9 { - return false, errors.Newf("pattern %q should be of length 9", pattern) - } - for i := 0; i < len(relation); i++ { - matches, err := relationByteMatchesPatternByte(relation[i], pattern[i]) - if err != nil { - return false, err - } - if !matches { - return false, nil - } - } - return true, nil -} - -// relationByteMatchesPatternByte matches a single byte of a DE-9IM relation -// against the DE-9IM pattern. -// Pattern matches are as follows: -// * '*': allow anything. -// * '0' / '1' / '2': match exactly. -// * 't'/'T': allow only if the relation is true. This means the relation must be -// '0' (point), '1' (line) or '2' (area) - which is the dimensionality of the -// intersection. -// * 'f'/'F': allow only if relation is also false, which is of the form 'f'/'F'. -func relationByteMatchesPatternByte(r byte, p byte) (bool, error) { - switch geo.ToLowerSingleByte(p) { - case '*': - return true, nil - case 't': - if r < '0' || r > '2' { - return false, nil - } - case 'f': - if geo.ToLowerSingleByte(r) != 'f' { - return false, nil - } - case '0', '1', '2': - return r == p, nil - default: - return false, errors.Newf("unrecognized pattern character: %s", string(p)) - } - return true, nil -} diff --git a/postgres/parser/geo/geomfn/distance.go b/postgres/parser/geo/geomfn/distance.go deleted file mode 100644 index fe3ab2cd4c..0000000000 --- a/postgres/parser/geo/geomfn/distance.go +++ /dev/null @@ -1,726 +0,0 @@ -// Copyright 2023 Dolthub, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Copyright 2020 The Cockroach Authors. -// -// Use of this software is governed by the Business Source License -// included in the file licenses/BSL.txt. -// -// As of the Change Date specified in that file, in accordance with -// the Business Source License, use of this software will be governed -// by the Apache License, Version 2.0, included in the file -// licenses/APL.txt. - -package geomfn - -import ( - "math" - - "github.com/cockroachdb/errors" - "github.com/twpayne/go-geom" - "github.com/twpayne/go-geom/xy/lineintersector" - - "github.com/dolthub/doltgresql/postgres/parser/geo" - "github.com/dolthub/doltgresql/postgres/parser/geo/geodist" -) - -// geometricalObjectsOrder allows us to preserve the order of geometrical objects to -// match the start and endpoint in the shortest and longest LineString. -type geometricalObjectsOrder int - -const ( - // geometricalObjectsFlipped represents that the given order of two geometrical - // objects has been flipped. - geometricalObjectsFlipped geometricalObjectsOrder = -1 - // geometricalObjectsNotFlipped represents that the given order of two geometrical - // objects has not been flipped. - geometricalObjectsNotFlipped geometricalObjectsOrder = 1 -) - -// MinDistance returns the minimum distance between geometries A and B. -// This returns a geo.EmptyGeometryError if either A or B is EMPTY. -func MinDistance(a geo.Geometry, b geo.Geometry) (float64, error) { - if a.SRID() != b.SRID() { - return 0, geo.NewMismatchingSRIDsError(a.SpatialObject(), b.SpatialObject()) - } - return minDistanceInternal(a, b, 0, geo.EmptyBehaviorOmit, geo.FnInclusive) -} - -// MaxDistance returns the maximum distance across every pair of points comprising -// geometries A and B. -func MaxDistance(a geo.Geometry, b geo.Geometry) (float64, error) { - if a.SRID() != b.SRID() { - return 0, geo.NewMismatchingSRIDsError(a.SpatialObject(), b.SpatialObject()) - } - return maxDistanceInternal(a, b, math.MaxFloat64, geo.EmptyBehaviorOmit, geo.FnInclusive) -} - -// DWithin determines if any part of geometry A is within D units of geometry B. -// If exclusive, DWithin is equivalent to Distance(a, b) < d. Otherwise, DWithin -// is equivalent to Distance(a, b) <= d. -func DWithin( - a geo.Geometry, b geo.Geometry, d float64, exclusivity geo.FnExclusivity, -) (bool, error) { - if a.SRID() != b.SRID() { - return false, geo.NewMismatchingSRIDsError(a.SpatialObject(), b.SpatialObject()) - } - if d < 0 { - return false, errors.Newf("dwithin distance cannot be less than zero") - } - if !a.CartesianBoundingBox().Buffer(d, d).Intersects(b.CartesianBoundingBox()) { - return false, nil - } - dist, err := minDistanceInternal(a, b, d, geo.EmptyBehaviorError, exclusivity) - if err != nil { - // In case of any empty geometries return false. - if geo.IsEmptyGeometryError(err) { - return false, nil - } - return false, err - } - if exclusivity == geo.FnExclusive { - return dist < d, nil - } - return dist <= d, nil -} - -// DFullyWithin determines whether the maximum distance across every pair of -// points comprising geometries A and B is within D units. If exclusive, -// DFullyWithin is equivalent to MaxDistance(a, b) < d. Otherwise, DFullyWithin -// is equivalent to MaxDistance(a, b) <= d. -func DFullyWithin( - a geo.Geometry, b geo.Geometry, d float64, exclusivity geo.FnExclusivity, -) (bool, error) { - if a.SRID() != b.SRID() { - return false, geo.NewMismatchingSRIDsError(a.SpatialObject(), b.SpatialObject()) - } - if d < 0 { - return false, errors.Newf("dwithin distance cannot be less than zero") - } - if !a.CartesianBoundingBox().Buffer(d, d).Covers(b.CartesianBoundingBox()) { - return false, nil - } - dist, err := maxDistanceInternal(a, b, d, geo.EmptyBehaviorError, exclusivity) - if err != nil { - // In case of any empty geometries return false. - if geo.IsEmptyGeometryError(err) { - return false, nil - } - return false, err - } - if exclusivity == geo.FnExclusive { - return dist < d, nil - } - return dist <= d, nil -} - -// LongestLineString returns the LineString corresponds to maximum distance across -// every pair of points comprising geometries A and B. -func LongestLineString(a geo.Geometry, b geo.Geometry) (geo.Geometry, error) { - if a.SRID() != b.SRID() { - return geo.Geometry{}, geo.NewMismatchingSRIDsError(a.SpatialObject(), b.SpatialObject()) - } - u := newGeomMaxDistanceUpdater(math.MaxFloat64, geo.FnInclusive) - return distanceLineStringInternal(a, b, u, geo.EmptyBehaviorOmit) -} - -// ShortestLineString returns the LineString corresponds to minimum distance across -// every pair of points comprising geometries A and B. -func ShortestLineString(a geo.Geometry, b geo.Geometry) (geo.Geometry, error) { - if a.SRID() != b.SRID() { - return geo.Geometry{}, geo.NewMismatchingSRIDsError(a.SpatialObject(), b.SpatialObject()) - } - u := newGeomMinDistanceUpdater(0 /*stopAfter */, geo.FnInclusive) - return distanceLineStringInternal(a, b, u, geo.EmptyBehaviorOmit) -} - -// distanceLineStringInternal calculates the LineString between two geometries using -// the DistanceCalculator operator. -// If there are any EMPTY Geometry objects, they will be ignored. It will return an -// EmptyGeometryError if A or B contains only EMPTY geometries, even if emptyBehavior -// is set to EmptyBehaviorOmit. -func distanceLineStringInternal( - a geo.Geometry, b geo.Geometry, u geodist.DistanceUpdater, emptyBehavior geo.EmptyBehavior, -) (geo.Geometry, error) { - c := &geomDistanceCalculator{updater: u, boundingBoxIntersects: a.CartesianBoundingBox().Intersects(b.CartesianBoundingBox())} - _, err := distanceInternal(a, b, c, emptyBehavior) - if err != nil { - return geo.Geometry{}, err - } - var coordA, coordB geom.Coord - switch u := u.(type) { - case *geomMaxDistanceUpdater: - coordA = u.coordA - coordB = u.coordB - case *geomMinDistanceUpdater: - coordA = u.coordA - coordB = u.coordB - default: - return geo.Geometry{}, errors.Newf("programmer error: unknown behavior") - } - lineString := geom.NewLineStringFlat(geom.XY, append(coordA, coordB...)).SetSRID(int(a.SRID())) - return geo.MakeGeometryFromGeomT(lineString) -} - -// maxDistanceInternal finds the maximum distance between two geometries. -// We can re-use the same algorithm as min-distance, allowing skips of checks that involve -// the interiors or intersections as those will always be less then the maximum min-distance. -func maxDistanceInternal( - a geo.Geometry, - b geo.Geometry, - stopAfter float64, - emptyBehavior geo.EmptyBehavior, - exclusivity geo.FnExclusivity, -) (float64, error) { - u := newGeomMaxDistanceUpdater(stopAfter, exclusivity) - c := &geomDistanceCalculator{updater: u, boundingBoxIntersects: a.CartesianBoundingBox().Intersects(b.CartesianBoundingBox())} - return distanceInternal(a, b, c, emptyBehavior) -} - -// minDistanceInternal finds the minimum distance between two geometries. -// This implementation is done in-house, as compared to using GEOS. -func minDistanceInternal( - a geo.Geometry, - b geo.Geometry, - stopAfter float64, - emptyBehavior geo.EmptyBehavior, - exclusivity geo.FnExclusivity, -) (float64, error) { - u := newGeomMinDistanceUpdater(stopAfter, exclusivity) - c := &geomDistanceCalculator{updater: u, boundingBoxIntersects: a.CartesianBoundingBox().Intersects(b.CartesianBoundingBox())} - return distanceInternal(a, b, c, emptyBehavior) -} - -// distanceInternal calculates the distance between two geometries using -// the DistanceCalculator operator. -// If there are any EMPTY Geometry objects, they will be ignored. It will return an -// EmptyGeometryError if A or B contains only EMPTY geometries, even if emptyBehavior -// is set to EmptyBehaviorOmit. -func distanceInternal( - a geo.Geometry, b geo.Geometry, c geodist.DistanceCalculator, emptyBehavior geo.EmptyBehavior, -) (float64, error) { - // If either side has no geoms, then we error out regardless of emptyBehavior. - if a.Empty() || b.Empty() { - return 0, geo.NewEmptyGeometryError() - } - - aGeomT, err := a.AsGeomT() - if err != nil { - return 0, err - } - bGeomT, err := b.AsGeomT() - if err != nil { - return 0, err - } - - // If we early exit, we have to check empty behavior upfront to return - // the appropriate error message. - // This matches PostGIS's behavior for DWithin, which is always false - // if at least one element is empty. - if emptyBehavior == geo.EmptyBehaviorError && - (geo.GeomTContainsEmpty(aGeomT) || geo.GeomTContainsEmpty(bGeomT)) { - return 0, geo.NewEmptyGeometryError() - } - - aIt := geo.NewGeomTIterator(aGeomT, emptyBehavior) - aGeom, aNext, aErr := aIt.Next() - if aErr != nil { - return 0, err - } - for aNext { - aGeodist, err := geomToGeodist(aGeom) - if err != nil { - return 0, err - } - - bIt := geo.NewGeomTIterator(bGeomT, emptyBehavior) - bGeom, bNext, bErr := bIt.Next() - if bErr != nil { - return 0, err - } - for bNext { - bGeodist, err := geomToGeodist(bGeom) - if err != nil { - return 0, err - } - earlyExit, err := geodist.ShapeDistance(c, aGeodist, bGeodist) - if err != nil { - return 0, err - } - if earlyExit { - return c.DistanceUpdater().Distance(), nil - } - - bGeom, bNext, bErr = bIt.Next() - if bErr != nil { - return 0, err - } - } - - aGeom, aNext, aErr = aIt.Next() - if aErr != nil { - return 0, err - } - } - return c.DistanceUpdater().Distance(), nil -} - -// geomToGeodist converts a given geom object to a geodist shape. -func geomToGeodist(g geom.T) (geodist.Shape, error) { - switch g := g.(type) { - case *geom.Point: - return &geodist.Point{GeomPoint: g.Coords()}, nil - case *geom.LineString: - return &geomGeodistLineString{LineString: g}, nil - case *geom.Polygon: - return &geomGeodistPolygon{Polygon: g}, nil - } - return nil, errors.Newf("could not find shape: %T", g) -} - -// geomGeodistLineString implements geodist.LineString. -type geomGeodistLineString struct { - *geom.LineString -} - -var _ geodist.LineString = (*geomGeodistLineString)(nil) - -// IsShape implements the geodist.LineString interface. -func (*geomGeodistLineString) IsShape() {} - -// LineString implements the geodist.LineString interface. -func (*geomGeodistLineString) IsLineString() {} - -// Edge implements the geodist.LineString interface. -func (g *geomGeodistLineString) Edge(i int) geodist.Edge { - return geodist.Edge{ - V0: geodist.Point{GeomPoint: g.LineString.Coord(i)}, - V1: geodist.Point{GeomPoint: g.LineString.Coord(i + 1)}, - } -} - -// NumEdges implements the geodist.LineString interface. -func (g *geomGeodistLineString) NumEdges() int { - return g.LineString.NumCoords() - 1 -} - -// Vertex implements the geodist.LineString interface. -func (g *geomGeodistLineString) Vertex(i int) geodist.Point { - return geodist.Point{GeomPoint: g.LineString.Coord(i)} -} - -// NumVertexes implements the geodist.LineString interface. -func (g *geomGeodistLineString) NumVertexes() int { - return g.LineString.NumCoords() -} - -// geomGeodistLinearRing implements geodist.LinearRing. -type geomGeodistLinearRing struct { - *geom.LinearRing -} - -var _ geodist.LinearRing = (*geomGeodistLinearRing)(nil) - -// IsShape implements the geodist.LinearRing interface. -func (*geomGeodistLinearRing) IsShape() {} - -// LinearRing implements the geodist.LinearRing interface. -func (*geomGeodistLinearRing) IsLinearRing() {} - -// Edge implements the geodist.LinearRing interface. -func (g *geomGeodistLinearRing) Edge(i int) geodist.Edge { - return geodist.Edge{ - V0: geodist.Point{GeomPoint: g.LinearRing.Coord(i)}, - V1: geodist.Point{GeomPoint: g.LinearRing.Coord(i + 1)}, - } -} - -// NumEdges implements the geodist.LinearRing interface. -func (g *geomGeodistLinearRing) NumEdges() int { - return g.LinearRing.NumCoords() - 1 -} - -// Vertex implements the geodist.LinearRing interface. -func (g *geomGeodistLinearRing) Vertex(i int) geodist.Point { - return geodist.Point{GeomPoint: g.LinearRing.Coord(i)} -} - -// NumVertexes implements the geodist.LinearRing interface. -func (g *geomGeodistLinearRing) NumVertexes() int { - return g.LinearRing.NumCoords() -} - -// geomGeodistPolygon implements geodist.Polygon. -type geomGeodistPolygon struct { - *geom.Polygon -} - -var _ geodist.Polygon = (*geomGeodistPolygon)(nil) - -// IsShape implements the geodist.Polygon interface. -func (*geomGeodistPolygon) IsShape() {} - -// Polygon implements the geodist.Polygon interface. -func (*geomGeodistPolygon) IsPolygon() {} - -// LinearRing implements the geodist.Polygon interface. -func (g *geomGeodistPolygon) LinearRing(i int) geodist.LinearRing { - return &geomGeodistLinearRing{LinearRing: g.Polygon.LinearRing(i)} -} - -// NumLinearRings implements the geodist.Polygon interface. -func (g *geomGeodistPolygon) NumLinearRings() int { - return g.Polygon.NumLinearRings() -} - -// geomGeodistEdgeCrosser implements geodist.EdgeCrosser. -type geomGeodistEdgeCrosser struct { - strategy lineintersector.Strategy - edgeV0 geom.Coord - edgeV1 geom.Coord - nextEdgeV0 geom.Coord -} - -var _ geodist.EdgeCrosser = (*geomGeodistEdgeCrosser)(nil) - -// ChainCrossing implements geodist.EdgeCrosser. -func (c *geomGeodistEdgeCrosser) ChainCrossing(p geodist.Point) (bool, geodist.Point) { - nextEdgeV1 := p.GeomPoint - result := lineintersector.LineIntersectsLine( - c.strategy, - c.edgeV0, - c.edgeV1, - c.nextEdgeV0, - nextEdgeV1, - ) - c.nextEdgeV0 = nextEdgeV1 - if result.HasIntersection() { - return true, geodist.Point{GeomPoint: result.Intersection()[0]} - } - return false, geodist.Point{} -} - -// geomMinDistanceUpdater finds the minimum distance using geom calculations. -// And preserve the line's endpoints as geom.Coord which corresponds to minimum -// distance. If inclusive, methods will return early if it finds a minimum -// distance <= stopAfter. Otherwise, methods will return early if it finds a -// minimum distance < stopAfter. -type geomMinDistanceUpdater struct { - currentValue float64 - stopAfter float64 - exclusivity geo.FnExclusivity - // coordA represents the first vertex of the edge that holds the maximum distance. - coordA geom.Coord - // coordB represents the second vertex of the edge that holds the maximum distance. - coordB geom.Coord - - geometricalObjOrder geometricalObjectsOrder -} - -var _ geodist.DistanceUpdater = (*geomMinDistanceUpdater)(nil) - -// newGeomMinDistanceUpdater returns a new geomMinDistanceUpdater with the -// correct arguments set up. -func newGeomMinDistanceUpdater( - stopAfter float64, exclusivity geo.FnExclusivity, -) *geomMinDistanceUpdater { - return &geomMinDistanceUpdater{ - currentValue: math.MaxFloat64, - stopAfter: stopAfter, - exclusivity: exclusivity, - coordA: nil, - coordB: nil, - geometricalObjOrder: geometricalObjectsNotFlipped, - } -} - -// Distance implements the geodist.DistanceUpdater interface. -func (u *geomMinDistanceUpdater) Distance() float64 { - return u.currentValue -} - -// Update implements the geodist.DistanceUpdater interface. -func (u *geomMinDistanceUpdater) Update(aPoint geodist.Point, bPoint geodist.Point) bool { - a := aPoint.GeomPoint - b := bPoint.GeomPoint - - dist := coordNorm(coordSub(a, b)) - if dist < u.currentValue { - u.currentValue = dist - if u.geometricalObjOrder == geometricalObjectsFlipped { - u.coordA = b - u.coordB = a - } else { - u.coordA = a - u.coordB = b - } - if u.exclusivity == geo.FnExclusive { - return dist < u.stopAfter - } - return dist <= u.stopAfter - } - return false -} - -// OnIntersects implements the geodist.DistanceUpdater interface. -func (u *geomMinDistanceUpdater) OnIntersects(p geodist.Point) bool { - u.coordA = p.GeomPoint - u.coordB = p.GeomPoint - u.currentValue = 0 - return true -} - -// IsMaxDistance implements the geodist.DistanceUpdater interface. -func (u *geomMinDistanceUpdater) IsMaxDistance() bool { - return false -} - -// FlipGeometries implements the geodist.DistanceUpdater interface. -func (u *geomMinDistanceUpdater) FlipGeometries() { - u.geometricalObjOrder = -u.geometricalObjOrder -} - -// geomMaxDistanceUpdater finds the maximum distance using geom calculations. -// And preserve the line's endpoints as geom.Coord which corresponds to maximum -// distance. If exclusive, methods will return early if it finds that -// distance >= stopAfter. Otherwise, methods will return early if distance > -// stopAfter. -type geomMaxDistanceUpdater struct { - currentValue float64 - stopAfter float64 - exclusivity geo.FnExclusivity - - // coordA represents the first vertex of the edge that holds the maximum distance. - coordA geom.Coord - // coordB represents the second vertex of the edge that holds the maximum distance. - coordB geom.Coord - - geometricalObjOrder geometricalObjectsOrder -} - -var _ geodist.DistanceUpdater = (*geomMaxDistanceUpdater)(nil) - -// newGeomMaxDistanceUpdater returns a new geomMaxDistanceUpdater with the -// correct arguments set up. currentValue is initially populated with least -// possible value instead of 0 because there may be the case where maximum -// distance is 0 and we may require to find the line for 0 maximum distance. -func newGeomMaxDistanceUpdater( - stopAfter float64, exclusivity geo.FnExclusivity, -) *geomMaxDistanceUpdater { - return &geomMaxDistanceUpdater{ - currentValue: -math.MaxFloat64, - stopAfter: stopAfter, - exclusivity: exclusivity, - coordA: nil, - coordB: nil, - geometricalObjOrder: geometricalObjectsNotFlipped, - } -} - -// Distance implements the geodist.DistanceUpdater interface. -func (u *geomMaxDistanceUpdater) Distance() float64 { - return u.currentValue -} - -// Update implements the geodist.DistanceUpdater interface. -func (u *geomMaxDistanceUpdater) Update(aPoint geodist.Point, bPoint geodist.Point) bool { - a := aPoint.GeomPoint - b := bPoint.GeomPoint - - dist := coordNorm(coordSub(a, b)) - if dist > u.currentValue { - u.currentValue = dist - if u.geometricalObjOrder == geometricalObjectsFlipped { - u.coordA = b - u.coordB = a - } else { - u.coordA = a - u.coordB = b - } - if u.exclusivity == geo.FnExclusive { - return dist >= u.stopAfter - } - return dist > u.stopAfter - } - return false -} - -// OnIntersects implements the geodist.DistanceUpdater interface. -func (u *geomMaxDistanceUpdater) OnIntersects(p geodist.Point) bool { - return false -} - -// IsMaxDistance implements the geodist.DistanceUpdater interface. -func (u *geomMaxDistanceUpdater) IsMaxDistance() bool { - return true -} - -// FlipGeometries implements the geodist.DistanceUpdater interface. -func (u *geomMaxDistanceUpdater) FlipGeometries() { - u.geometricalObjOrder = -u.geometricalObjOrder -} - -// geomDistanceCalculator implements geodist.DistanceCalculator -type geomDistanceCalculator struct { - updater geodist.DistanceUpdater - boundingBoxIntersects bool -} - -var _ geodist.DistanceCalculator = (*geomDistanceCalculator)(nil) - -// DistanceUpdater implements geodist.DistanceCalculator. -func (c *geomDistanceCalculator) DistanceUpdater() geodist.DistanceUpdater { - return c.updater -} - -// BoundingBoxIntersects implements geodist.DistanceCalculator. -func (c *geomDistanceCalculator) BoundingBoxIntersects() bool { - return c.boundingBoxIntersects -} - -// NewEdgeCrosser implements geodist.DistanceCalculator. -func (c *geomDistanceCalculator) NewEdgeCrosser( - edge geodist.Edge, startPoint geodist.Point, -) geodist.EdgeCrosser { - return &geomGeodistEdgeCrosser{ - strategy: &lineintersector.NonRobustLineIntersector{}, - edgeV0: edge.V0.GeomPoint, - edgeV1: edge.V1.GeomPoint, - nextEdgeV0: startPoint.GeomPoint, - } -} - -// side corresponds to the side in which a point is relative to a line. -type pointSide int - -const ( - pointSideLeft pointSide = -1 - pointSideOn pointSide = 0 - pointSideRight pointSide = 1 -) - -// findPointSide finds which side a point is relative to the infinite line -// given by the edge. -// Note this side is relative to the orientation of the line. -func findPointSide(p geom.Coord, eV0 geom.Coord, eV1 geom.Coord) pointSide { - // This is the equivalent of using the point-gradient formula - // and determining the sign, i.e. the sign of - // d = (x-x1)(y2-y1) - (y-y1)(x2-x1) - // where (x1,y1) and (x2,y2) is the edge and (x,y) is the point - sign := (p.X()-eV0.X())*(eV1.Y()-eV0.Y()) - (eV1.X()-eV0.X())*(p.Y()-eV0.Y()) - switch { - case sign == 0: - return pointSideOn - case sign > 0: - return pointSideRight - default: - return pointSideLeft - } -} - -// PointInLinearRing implements geodist.DistanceCalculator. -func (c *geomDistanceCalculator) PointInLinearRing( - point geodist.Point, polygon geodist.LinearRing, -) bool { - // This is done using the winding number algorithm, also known as the - // "non-zero rule". - // See: https://en.wikipedia.org/wiki/Point_in_polygon for intro. - // See: http://geomalgorithms.com/a03-_inclusion.html for algorithm. - // See also: https://en.wikipedia.org/wiki/Winding_number - // See also: https://en.wikipedia.org/wiki/Nonzero-rule - windingNumber := 0 - p := point.GeomPoint - for edgeIdx := 0; edgeIdx < polygon.NumEdges(); edgeIdx++ { - e := polygon.Edge(edgeIdx) - eV0 := e.V0.GeomPoint - eV1 := e.V1.GeomPoint - // Same vertex; none of these checks will pass. - if coordEqual(eV0, eV1) { - continue - } - yMin := math.Min(eV0.Y(), eV1.Y()) - yMax := math.Max(eV0.Y(), eV1.Y()) - // If the edge isn't on the same level as Y, this edge isn't worth considering. - if p.Y() > yMax || p.Y() < yMin { - continue - } - side := findPointSide(p, eV0, eV1) - // If the point is on the line if the edge was infinite, and the point is within the bounds - // of the line segment denoted by the edge, there is a covering. - if side == pointSideOn && - ((eV0.X() <= p.X() && p.X() < eV1.X()) || (eV1.X() <= p.X() && p.X() < eV0.X()) || - (eV0.Y() <= p.Y() && p.Y() < eV1.Y()) || (eV1.Y() <= p.Y() && p.Y() < eV0.Y())) { - return true - } - // If the point is left of the segment and the line is rising - // we have a circle going CCW, so increment. - // Note we only compare [start, end) as we do not want to double count points - // which are on the same X / Y axis as an edge vertex. - if side == pointSideLeft && eV0.Y() <= p.Y() && p.Y() < eV1.Y() { - windingNumber++ - } - // If the line is to the right of the segment and the - // line is falling, we a have a circle going CW so decrement. - // Note we only compare [start, end) as we do not want to double count points - // which are on the same X / Y axis as an edge vertex. - if side == pointSideRight && eV1.Y() <= p.Y() && p.Y() < eV0.Y() { - windingNumber-- - } - } - return windingNumber != 0 -} - -// ClosestPointToEdge implements geodist.DistanceCalculator. -func (c *geomDistanceCalculator) ClosestPointToEdge( - e geodist.Edge, p geodist.Point, -) (geodist.Point, bool) { - // Edge is a single point. Closest point must be any edge vertex. - if coordEqual(e.V0.GeomPoint, e.V1.GeomPoint) { - return e.V0, coordEqual(e.V0.GeomPoint, p.GeomPoint) - } - - // From http://www.faqs.org/faqs/graphics/algorithms-faq/, section 1.02 - // - // Let the point be C (Cx,Cy) and the line be AB (Ax,Ay) to (Bx,By). - // Let P be the point of perpendicular projection of C on AB. The parameter - // r, which indicates P's position along AB, is computed by the dot product - // of AC and AB divided by the square of the length of AB: - // - // (1) AC dot AB - // r = --------- - // ||AB||^2 - // - // r has the following meaning: - // - // r=0 P = A - // r=1 P = B - // r<0 P is on the backward extension of AB - // r>1 P is on the forward extension of AB - // 0 1 { - return p, false - } - return geodist.Point{GeomPoint: coordAdd(e.V0.GeomPoint, coordMul(ab, r))}, true -} diff --git a/postgres/parser/geo/geomfn/envelope.go b/postgres/parser/geo/geomfn/envelope.go deleted file mode 100644 index f04da46e81..0000000000 --- a/postgres/parser/geo/geomfn/envelope.go +++ /dev/null @@ -1,37 +0,0 @@ -// Copyright 2023 Dolthub, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Copyright 2020 The Cockroach Authors. -// -// Use of this software is governed by the Business Source License -// included in the file licenses/BSL.txt. -// -// As of the Change Date specified in that file, in accordance with -// the Business Source License, use of this software will be governed -// by the Apache License, Version 2.0, included in the file -// licenses/APL.txt. - -package geomfn - -import "github.com/dolthub/doltgresql/postgres/parser/geo" - -// Envelope forms an envelope (compliant with the OGC spec) of the given Geometry. -// It uses the bounding box to return a Polygon, but can return a Point or -// Line if the bounding box is degenerate and not a box. -func Envelope(g geo.Geometry) (geo.Geometry, error) { - if g.Empty() { - return g, nil - } - return geo.MakeGeometryFromGeomT(g.CartesianBoundingBox().ToGeomT(g.SRID())) -} diff --git a/postgres/parser/geo/geomfn/flip_coordinates.go b/postgres/parser/geo/geomfn/flip_coordinates.go deleted file mode 100644 index 9b4a269b12..0000000000 --- a/postgres/parser/geo/geomfn/flip_coordinates.go +++ /dev/null @@ -1,59 +0,0 @@ -// Copyright 2023 Dolthub, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Copyright 2020 The Cockroach Authors. -// -// Use of this software is governed by the Business Source License -// included in the file licenses/BSL.txt. -// -// As of the Change Date specified in that file, in accordance with -// the Business Source License, use of this software will be governed -// by the Apache License, Version 2.0, included in the file -// licenses/APL.txt. - -package geomfn - -import ( - "github.com/twpayne/go-geom" - - "github.com/dolthub/doltgresql/postgres/parser/geo" -) - -// FlipCoordinates returns a modified g whose X, Y coordinates are flipped. -func FlipCoordinates(g geo.Geometry) (geo.Geometry, error) { - if g.Empty() { - return g, nil - } - - t, err := g.AsGeomT() - if err != nil { - return geo.Geometry{}, err - } - - newT, err := applyOnCoordsForGeomT(t, func(l geom.Layout, dst, src []float64) error { - dst[0], dst[1] = src[1], src[0] - if l.ZIndex() != -1 { - dst[l.ZIndex()] = src[l.ZIndex()] - } - if l.MIndex() != -1 { - dst[l.MIndex()] = src[l.MIndex()] - } - return nil - }) - if err != nil { - return geo.Geometry{}, err - } - - return geo.MakeGeometryFromGeomT(newT) -} diff --git a/postgres/parser/geo/geomfn/force_layout.go b/postgres/parser/geo/geomfn/force_layout.go deleted file mode 100644 index d36c3b2e59..0000000000 --- a/postgres/parser/geo/geomfn/force_layout.go +++ /dev/null @@ -1,138 +0,0 @@ -// Copyright 2023 Dolthub, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Copyright 2020 The Cockroach Authors. -// -// Use of this software is governed by the Business Source License -// included in the file licenses/BSL.txt. -// -// As of the Change Date specified in that file, in accordance with -// the Business Source License, use of this software will be governed -// by the Apache License, Version 2.0, included in the file -// licenses/APL.txt. - -package geomfn - -import ( - "github.com/cockroachdb/errors" - "github.com/twpayne/go-geom" - - "github.com/dolthub/doltgresql/postgres/parser/geo" -) - -// ForceLayout forces a geometry into the given layout. -// If dimensions are added, 0 coordinates are padded to them. -func ForceLayout(g geo.Geometry, layout geom.Layout) (geo.Geometry, error) { - geomT, err := g.AsGeomT() - if err != nil { - return geo.Geometry{}, err - } - retGeomT, err := forceLayout(geomT, layout) - if err != nil { - return geo.Geometry{}, err - } - return geo.MakeGeometryFromGeomT(retGeomT) -} - -// forceLayout forces a geom.T into the given layout. -func forceLayout(t geom.T, layout geom.Layout) (geom.T, error) { - if t.Layout() == layout { - return t, nil - } - switch t := t.(type) { - case *geom.GeometryCollection: - ret := geom.NewGeometryCollection().SetSRID(t.SRID()) - for i := 0; i < t.NumGeoms(); i++ { - toPush, err := forceLayout(t.Geom(i), layout) - if err != nil { - return nil, err - } - if err := ret.Push(toPush); err != nil { - return nil, err - } - } - return ret, nil - case *geom.Point: - return geom.NewPointFlat(layout, forceFlatCoordsLayout(t, layout)).SetSRID(t.SRID()), nil - case *geom.LineString: - return geom.NewLineStringFlat(layout, forceFlatCoordsLayout(t, layout)).SetSRID(t.SRID()), nil - case *geom.Polygon: - return geom.NewPolygonFlat( - layout, - forceFlatCoordsLayout(t, layout), - forceEnds(t.Ends(), t.Layout(), layout), - ).SetSRID(t.SRID()), nil - case *geom.MultiPoint: - return geom.NewMultiPointFlat( - layout, - forceFlatCoordsLayout(t, layout), - geom.NewMultiPointFlatOptionWithEnds(forceEnds(t.Ends(), t.Layout(), layout)), - ).SetSRID(t.SRID()), nil - case *geom.MultiLineString: - return geom.NewMultiLineStringFlat( - layout, - forceFlatCoordsLayout(t, layout), - forceEnds(t.Ends(), t.Layout(), layout), - ).SetSRID(t.SRID()), nil - case *geom.MultiPolygon: - endss := make([][]int, len(t.Endss())) - for i := range t.Endss() { - endss[i] = forceEnds(t.Endss()[i], t.Layout(), layout) - } - return geom.NewMultiPolygonFlat( - layout, - forceFlatCoordsLayout(t, layout), - endss, - ).SetSRID(t.SRID()), nil - default: - return nil, errors.Newf("unknown geom.T type: %T", t) - } -} - -// forceEnds forces the Endss layout of a geometry into the new layout. -func forceEnds(ends []int, oldLayout geom.Layout, newLayout geom.Layout) []int { - if oldLayout.Stride() == newLayout.Stride() { - return ends - } - newEnds := make([]int, len(ends)) - for i := range ends { - newEnds[i] = (ends[i] / oldLayout.Stride()) * newLayout.Stride() - } - return newEnds -} - -// forceFlatCoordsLayout forces the flatCoords layout of a geometry into the new layout. -func forceFlatCoordsLayout(t geom.T, layout geom.Layout) []float64 { - oldFlatCoords := t.FlatCoords() - newFlatCoords := make([]float64, (len(oldFlatCoords)/t.Stride())*layout.Stride()) - for coordIdx := 0; coordIdx < len(oldFlatCoords)/t.Stride(); coordIdx++ { - newFlatCoords[coordIdx*layout.Stride()] = oldFlatCoords[coordIdx*t.Stride()] - newFlatCoords[coordIdx*layout.Stride()+1] = oldFlatCoords[coordIdx*t.Stride()+1] - if layout.ZIndex() != -1 { - z := float64(0) - if t.Layout().ZIndex() != -1 { - z = oldFlatCoords[coordIdx*t.Stride()+t.Layout().ZIndex()] - } - newFlatCoords[coordIdx*layout.Stride()+layout.ZIndex()] = z - } - if layout.MIndex() != -1 { - m := float64(0) - if t.Layout().MIndex() != -1 { - m = oldFlatCoords[coordIdx*t.Stride()+t.Layout().MIndex()] - } - newFlatCoords[coordIdx*layout.Stride()+layout.MIndex()] = m - } - } - return newFlatCoords -} diff --git a/postgres/parser/geo/geomfn/geomfn.go b/postgres/parser/geo/geomfn/geomfn.go deleted file mode 100644 index d365d64430..0000000000 --- a/postgres/parser/geo/geomfn/geomfn.go +++ /dev/null @@ -1,95 +0,0 @@ -// Copyright 2023 Dolthub, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Copyright 2020 The Cockroach Authors. -// -// Use of this software is governed by the Business Source License -// included in the file licenses/BSL.txt. -// -// As of the Change Date specified in that file, in accordance with -// the Business Source License, use of this software will be governed -// by the Apache License, Version 2.0, included in the file -// licenses/APL.txt. - -// Package geomfn contains functions that are used for geometry-based builtins. -package geomfn - -import "github.com/twpayne/go-geom" - -// applyCoordFunc applies a function on src to copy onto dst. -// Both slices represent a single Coord within the FlatCoord array. -type applyCoordFunc func(l geom.Layout, dst []float64, src []float64) error - -// applyOnCoords applies the applyCoordFunc on each coordinate, returning -// a new array for the coordinates. -func applyOnCoords(flatCoords []float64, l geom.Layout, f applyCoordFunc) ([]float64, error) { - newCoords := make([]float64, len(flatCoords)) - for i := 0; i < len(flatCoords); i += l.Stride() { - if err := f(l, newCoords[i:i+l.Stride()], flatCoords[i:i+l.Stride()]); err != nil { - return nil, err - } - } - return newCoords, nil -} - -// applyOnCoordsForGeomT applies the applyCoordFunc on each coordinate in the geom.T, -// returning a copied over geom.T. -func applyOnCoordsForGeomT(g geom.T, f applyCoordFunc) (geom.T, error) { - if geomCollection, ok := g.(*geom.GeometryCollection); ok { - return applyOnCoordsForGeometryCollection(geomCollection, f) - } - - newCoords, err := applyOnCoords(g.FlatCoords(), g.Layout(), f) - if err != nil { - return nil, err - } - - switch t := g.(type) { - case *geom.Point: - g = geom.NewPointFlat(t.Layout(), newCoords).SetSRID(g.SRID()) - case *geom.LineString: - g = geom.NewLineStringFlat(t.Layout(), newCoords).SetSRID(g.SRID()) - case *geom.Polygon: - g = geom.NewPolygonFlat(t.Layout(), newCoords, t.Ends()).SetSRID(g.SRID()) - case *geom.MultiPoint: - g = geom.NewMultiPointFlat(t.Layout(), newCoords).SetSRID(g.SRID()) - case *geom.MultiLineString: - g = geom.NewMultiLineStringFlat(t.Layout(), newCoords, t.Ends()).SetSRID(g.SRID()) - case *geom.MultiPolygon: - g = geom.NewMultiPolygonFlat(t.Layout(), newCoords, t.Endss()).SetSRID(g.SRID()) - default: - return nil, geom.ErrUnsupportedType{Value: g} - } - - return g, nil -} - -// applyOnCoordsForGeometryCollection applies the applyCoordFunc on each coordinate -// inside a geometry collection, returning a copied over geom.T. -func applyOnCoordsForGeometryCollection( - geomCollection *geom.GeometryCollection, f applyCoordFunc, -) (*geom.GeometryCollection, error) { - res := geom.NewGeometryCollection() - for _, subG := range geomCollection.Geoms() { - subGeom, err := applyOnCoordsForGeomT(subG, f) - if err != nil { - return nil, err - } - - if err := res.Push(subGeom); err != nil { - return nil, err - } - } - return res, nil -} diff --git a/postgres/parser/geo/geomfn/linear_reference.go b/postgres/parser/geo/geomfn/linear_reference.go deleted file mode 100644 index 42a5e32a25..0000000000 --- a/postgres/parser/geo/geomfn/linear_reference.go +++ /dev/null @@ -1,80 +0,0 @@ -// Copyright 2023 Dolthub, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Copyright 2020 The Cockroach Authors. -// -// Use of this software is governed by the Business Source License -// included in the file licenses/BSL.txt. -// -// As of the Change Date specified in that file, in accordance with -// the Business Source License, use of this software will be governed -// by the Apache License, Version 2.0, included in the file -// licenses/APL.txt. - -package geomfn - -import ( - "github.com/cockroachdb/errors" - "github.com/twpayne/go-geom" - "github.com/twpayne/go-geom/encoding/ewkb" - - "github.com/dolthub/doltgresql/postgres/parser/geo" - "github.com/dolthub/doltgresql/postgres/parser/geo/geos" -) - -// LineInterpolatePoints returns one or more points along the given -// LineString which are at an integral multiples of given fraction of -// LineString's total length. When repeat is set to false, it returns -// the first point. -func LineInterpolatePoints(g geo.Geometry, fraction float64, repeat bool) (geo.Geometry, error) { - if fraction < 0 || fraction > 1 { - return geo.Geometry{}, errors.Newf("fraction %f should be within [0 1] range", fraction) - } - geomRepr, err := g.AsGeomT() - if err != nil { - return geo.Geometry{}, err - } - switch geomRepr := geomRepr.(type) { - case *geom.LineString: - // In case fraction is greater than 0.5 or equal to 0 or repeat is false, - // then we will have only one interpolated point. - lengthOfLineString := geomRepr.Length() - if repeat && fraction <= 0.5 && fraction != 0 { - numberOfInterpolatedPoints := int(1 / fraction) - interpolatedPoints := geom.NewMultiPoint(geom.XY).SetSRID(geomRepr.SRID()) - for pointInserted := 1; pointInserted <= numberOfInterpolatedPoints; pointInserted++ { - pointEWKB, err := geos.InterpolateLine(g.EWKB(), float64(pointInserted)*fraction*lengthOfLineString) - if err != nil { - return geo.Geometry{}, err - } - point, err := ewkb.Unmarshal(pointEWKB) - if err != nil { - return geo.Geometry{}, err - } - err = interpolatedPoints.Push(point.(*geom.Point)) - if err != nil { - return geo.Geometry{}, err - } - } - return geo.MakeGeometryFromGeomT(interpolatedPoints) - } - interpolatedPointEWKB, err := geos.InterpolateLine(g.EWKB(), fraction*lengthOfLineString) - if err != nil { - return geo.Geometry{}, err - } - return geo.ParseGeometryFromEWKB(interpolatedPointEWKB) - default: - return geo.Geometry{}, errors.Newf("geometry %s should be LineString", g.ShapeType()) - } -} diff --git a/postgres/parser/geo/geomfn/linestring.go b/postgres/parser/geo/geomfn/linestring.go deleted file mode 100644 index f524eeb429..0000000000 --- a/postgres/parser/geo/geomfn/linestring.go +++ /dev/null @@ -1,220 +0,0 @@ -// Copyright 2023 Dolthub, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Copyright 2020 The Cockroach Authors. -// -// Use of this software is governed by the Business Source License -// included in the file licenses/BSL.txt. -// -// As of the Change Date specified in that file, in accordance with -// the Business Source License, use of this software will be governed -// by the Apache License, Version 2.0, included in the file -// licenses/APL.txt. - -package geomfn - -import ( - "github.com/cockroachdb/errors" - "github.com/twpayne/go-geom" - - "github.com/dolthub/doltgresql/postgres/parser/geo" - "github.com/dolthub/doltgresql/postgres/parser/geo/geos" -) - -// LineStringFromMultiPoint generates a linestring from a multipoint. -func LineStringFromMultiPoint(g geo.Geometry) (geo.Geometry, error) { - t, err := g.AsGeomT() - if err != nil { - return geo.Geometry{}, err - } - mp, ok := t.(*geom.MultiPoint) - if !ok { - return geo.Geometry{}, errors.Wrap(geom.ErrUnsupportedType{Value: t}, - "geometry must be a MultiPoint") - } - if mp.NumPoints() == 1 { - return geo.Geometry{}, errors.Newf("a LineString must have at least 2 points") - } - lineString := geom.NewLineString(mp.Layout()).SetSRID(mp.SRID()) - lineString, err = lineString.SetCoords(mp.Coords()) - if err != nil { - return geo.Geometry{}, err - } - return geo.MakeGeometryFromGeomT(lineString) -} - -// LineMerge merges multilinestring constituents. -func LineMerge(g geo.Geometry) (geo.Geometry, error) { - // Mirrors PostGIS behavior - if g.Empty() { - return g, nil - } - ret, err := geos.LineMerge(g.EWKB()) - if err != nil { - return geo.Geometry{}, err - } - return geo.ParseGeometryFromEWKB(ret) -} - -// AddPoint adds a point to a LineString at the given 0-based index. -1 appends. -func AddPoint(lineString geo.Geometry, index int, point geo.Geometry) (geo.Geometry, error) { - g, err := lineString.AsGeomT() - if err != nil { - return geo.Geometry{}, err - } - lineStringG, ok := g.(*geom.LineString) - if !ok { - e := geom.ErrUnsupportedType{Value: g} - return geo.Geometry{}, errors.Wrap(e, "geometry to be modified must be a LineString") - } - - g, err = point.AsGeomT() - if err != nil { - return geo.Geometry{}, err - } - - pointG, ok := g.(*geom.Point) - if !ok { - e := geom.ErrUnsupportedType{Value: g} - return geo.Geometry{}, errors.Wrapf(e, "invalid geometry used to add a Point to a LineString") - } - - g, err = addPoint(lineStringG, index, pointG) - if err != nil { - return geo.Geometry{}, err - } - - return geo.MakeGeometryFromGeomT(g) -} - -func addPoint(lineString *geom.LineString, index int, point *geom.Point) (*geom.LineString, error) { - if lineString.Layout() != point.Layout() { - return nil, geom.ErrLayoutMismatch{Got: point.Layout(), Want: lineString.Layout()} - } - if point.Empty() { - point = geom.NewPointFlat(point.Layout(), make([]float64, point.Stride())) - } - - coords := lineString.Coords() - - if index > len(coords) { - return nil, errors.Newf("index %d out of range of LineString with %d coordinates", - index, len(coords)) - } else if index == -1 { - index = len(coords) - } else if index < 0 { - return nil, errors.Newf("invalid index %v", index) - } - - // Shift the slice right by one element, then replace the element at the index, to avoid - // allocating an additional slice. - coords = append(coords, geom.Coord{}) - copy(coords[index+1:], coords[index:]) - coords[index] = point.Coords() - - return lineString.SetCoords(coords) -} - -// SetPoint sets the point at the given index of lineString; index is 0-based. -func SetPoint(lineString geo.Geometry, index int, point geo.Geometry) (geo.Geometry, error) { - g, err := lineString.AsGeomT() - if err != nil { - return geo.Geometry{}, err - } - - lineStringG, ok := g.(*geom.LineString) - if !ok { - e := geom.ErrUnsupportedType{Value: g} - return geo.Geometry{}, errors.Wrap(e, "geometry to be modified must be a LineString") - } - - g, err = point.AsGeomT() - if err != nil { - return geo.Geometry{}, err - } - - pointG, ok := g.(*geom.Point) - if !ok { - e := geom.ErrUnsupportedType{Value: g} - return geo.Geometry{}, errors.Wrapf(e, "invalid geometry used to replace a Point on a LineString") - } - - g, err = setPoint(lineStringG, index, pointG) - if err != nil { - return geo.Geometry{}, err - } - - return geo.MakeGeometryFromGeomT(g) -} - -func setPoint(lineString *geom.LineString, index int, point *geom.Point) (*geom.LineString, error) { - if lineString.Layout() != point.Layout() { - return nil, geom.ErrLayoutMismatch{Got: point.Layout(), Want: lineString.Layout()} - } - if point.Empty() { - point = geom.NewPointFlat(point.Layout(), make([]float64, point.Stride())) - } - - coords := lineString.Coords() - hasNegIndex := index < 0 - - if index >= len(coords) || (hasNegIndex && index*-1 > len(coords)) { - return nil, errors.Newf("index %d out of range of LineString with %d coordinates", index, len(coords)) - } - - if hasNegIndex { - index = len(coords) + index - } - - coords[index].Set(point.Coords()) - - return lineString.SetCoords(coords) -} - -// RemovePoint removes the point at the given index of lineString; index is 0-based. -func RemovePoint(lineString geo.Geometry, index int) (geo.Geometry, error) { - g, err := lineString.AsGeomT() - if err != nil { - return geo.Geometry{}, err - } - - lineStringG, ok := g.(*geom.LineString) - if !ok { - e := geom.ErrUnsupportedType{Value: g} - return geo.Geometry{}, errors.Wrap(e, "geometry to be modified must be a LineString") - } - - if lineStringG.NumCoords() == 2 { - return geo.Geometry{}, errors.Newf("cannot remove a point from a LineString with only two Points") - } - - g, err = removePoint(lineStringG, index) - if err != nil { - return geo.Geometry{}, err - } - - return geo.MakeGeometryFromGeomT(g) -} - -func removePoint(lineString *geom.LineString, index int) (*geom.LineString, error) { - coords := lineString.Coords() - - if index >= len(coords) || index < 0 { - return nil, errors.Newf("index %d out of range of LineString with %d coordinates", index, len(coords)) - } - - coords = append(coords[:index], coords[index+1:]...) - - return lineString.SetCoords(coords) -} diff --git a/postgres/parser/geo/geomfn/make_geometry.go b/postgres/parser/geo/geomfn/make_geometry.go deleted file mode 100644 index fdadfa2375..0000000000 --- a/postgres/parser/geo/geomfn/make_geometry.go +++ /dev/null @@ -1,69 +0,0 @@ -// Copyright 2023 Dolthub, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Copyright 2020 The Cockroach Authors. -// -// Use of this software is governed by the Business Source License -// included in the file licenses/BSL.txt. -// -// As of the Change Date specified in that file, in accordance with -// the Business Source License, use of this software will be governed -// by the Apache License, Version 2.0, included in the file -// licenses/APL.txt. - -package geomfn - -import ( - "github.com/cockroachdb/errors" - "github.com/twpayne/go-geom" - - "github.com/dolthub/doltgresql/postgres/parser/geo" -) - -// MakePolygon creates a Polygon geometry from linestring and optional inner linestrings. -// Returns errors if geometries are not linestrings. -func MakePolygon(outer geo.Geometry, interior ...geo.Geometry) (geo.Geometry, error) { - layout := geom.XY - outerGeomT, err := outer.AsGeomT() - if err != nil { - return geo.Geometry{}, err - } - outerRing, ok := outerGeomT.(*geom.LineString) - if !ok { - return geo.Geometry{}, errors.Newf("argument must be LINESTRING geometries") - } - srid := outerRing.SRID() - coords := make([][]geom.Coord, len(interior)+1) - coords[0] = outerRing.Coords() - for i, g := range interior { - interiorRingGeomT, err := g.AsGeomT() - if err != nil { - return geo.Geometry{}, err - } - interiorRing, ok := interiorRingGeomT.(*geom.LineString) - if !ok { - return geo.Geometry{}, errors.Newf("argument must be LINESTRING geometries") - } - if interiorRing.SRID() != srid { - return geo.Geometry{}, errors.Newf("mixed SRIDs are not allowed") - } - coords[i+1] = interiorRing.Coords() - } - - polygon, err := geom.NewPolygon(layout).SetSRID(srid).SetCoords(coords) - if err != nil { - return geo.Geometry{}, err - } - return geo.MakeGeometryFromGeomT(polygon) -} diff --git a/postgres/parser/geo/geomfn/orientation.go b/postgres/parser/geo/geomfn/orientation.go deleted file mode 100644 index 5dea2bd91f..0000000000 --- a/postgres/parser/geo/geomfn/orientation.go +++ /dev/null @@ -1,169 +0,0 @@ -// Copyright 2023 Dolthub, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Copyright 2020 The Cockroach Authors. -// -// Use of this software is governed by the Business Source License -// included in the file licenses/BSL.txt. -// -// As of the Change Date specified in that file, in accordance with -// the Business Source License, use of this software will be governed -// by the Apache License, Version 2.0, included in the file -// licenses/APL.txt. - -package geomfn - -import ( - "github.com/cockroachdb/errors" - "github.com/twpayne/go-geom" - - "github.com/dolthub/doltgresql/postgres/parser/geo" -) - -// Orientation defines an orientation of a shape. -type Orientation int - -const ( - // OrientationCW denotes a clockwise orientation. - OrientationCW Orientation = iota - // OrientationCCW denotes a counter-clockwise orientation - OrientationCCW -) - -// HasPolygonOrientation checks whether a given Geometry have polygons -// that matches the given Orientation. -// Non-Polygon objects -func HasPolygonOrientation(g geo.Geometry, o Orientation) (bool, error) { - t, err := g.AsGeomT() - if err != nil { - return false, err - } - return hasPolygonOrientation(t, o) -} - -func hasPolygonOrientation(g geom.T, o Orientation) (bool, error) { - switch g := g.(type) { - case *geom.Polygon: - for i := 0; i < g.NumLinearRings(); i++ { - isCCW := geo.IsLinearRingCCW(g.LinearRing(i)) - // Interior rings should be the reverse orientation of the exterior ring. - if i > 0 { - isCCW = !isCCW - } - switch o { - case OrientationCW: - if isCCW { - return false, nil - } - case OrientationCCW: - if !isCCW { - return false, nil - } - default: - return false, errors.Newf("unexpected orientation: %v", o) - } - } - return true, nil - case *geom.MultiPolygon: - for i := 0; i < g.NumPolygons(); i++ { - if ret, err := hasPolygonOrientation(g.Polygon(i), o); !ret || err != nil { - return ret, err - } - } - return true, nil - case *geom.GeometryCollection: - for i := 0; i < g.NumGeoms(); i++ { - if ret, err := hasPolygonOrientation(g.Geom(i), o); !ret || err != nil { - return ret, err - } - } - return true, nil - case *geom.Point, *geom.MultiPoint, *geom.LineString, *geom.MultiLineString: - return true, nil - default: - return false, errors.Newf("unhandled geometry type: %T", g) - } -} - -// ForcePolygonOrientation forces orientations within polygons -// to be oriented the prescribed way. -func ForcePolygonOrientation(g geo.Geometry, o Orientation) (geo.Geometry, error) { - t, err := g.AsGeomT() - if err != nil { - return geo.Geometry{}, err - } - - if err := forcePolygonOrientation(t, o); err != nil { - return geo.Geometry{}, err - } - return geo.MakeGeometryFromGeomT(t) -} - -func forcePolygonOrientation(g geom.T, o Orientation) error { - switch g := g.(type) { - case *geom.Polygon: - for i := 0; i < g.NumLinearRings(); i++ { - isCCW := geo.IsLinearRingCCW(g.LinearRing(i)) - // Interior rings should be the reverse orientation of the exterior ring. - if i > 0 { - isCCW = !isCCW - } - reverse := false - switch o { - case OrientationCW: - if isCCW { - reverse = true - } - case OrientationCCW: - if !isCCW { - reverse = true - } - default: - return errors.Newf("unexpected orientation: %v", o) - } - - if reverse { - // Reverse coordinates from both ends. - // Do this by swapping up to the middle of the array of elements, which guarantees - // each end get swapped. This works for an odd number of elements as well as - // the middle element ends swapping with itself, which is ok. - coords := g.LinearRing(i).FlatCoords() - for cIdx := 0; cIdx < len(coords)/2; cIdx += g.Stride() { - for sIdx := 0; sIdx < g.Stride(); sIdx++ { - coords[cIdx+sIdx], coords[len(coords)-cIdx-g.Stride()+sIdx] = coords[len(coords)-cIdx-g.Stride()+sIdx], coords[cIdx+sIdx] - } - } - } - } - return nil - case *geom.MultiPolygon: - for i := 0; i < g.NumPolygons(); i++ { - if err := forcePolygonOrientation(g.Polygon(i), o); err != nil { - return err - } - } - return nil - case *geom.GeometryCollection: - for i := 0; i < g.NumGeoms(); i++ { - if err := forcePolygonOrientation(g.Geom(i), o); err != nil { - return err - } - } - return nil - case *geom.Point, *geom.MultiPoint, *geom.LineString, *geom.MultiLineString: - return nil - default: - return errors.Newf("unhandled geometry type: %T", g) - } -} diff --git a/postgres/parser/geo/geomfn/remove_repeated_points.go b/postgres/parser/geo/geomfn/remove_repeated_points.go deleted file mode 100644 index 4fef489473..0000000000 --- a/postgres/parser/geo/geomfn/remove_repeated_points.go +++ /dev/null @@ -1,134 +0,0 @@ -// Copyright 2023 Dolthub, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Copyright 2020 The Cockroach Authors. -// -// Use of this software is governed by the Business Source License -// included in the file licenses/BSL.txt. -// -// As of the Change Date specified in that file, in accordance with -// the Business Source License, use of this software will be governed -// by the Apache License, Version 2.0, included in the file -// licenses/APL.txt. - -package geomfn - -import ( - "github.com/cockroachdb/errors" - "github.com/twpayne/go-geom" - - "github.com/dolthub/doltgresql/postgres/parser/geo" -) - -// RemoveRepeatedPoints returns the geometry with repeated points removed. -func RemoveRepeatedPoints(g geo.Geometry, tolerance float64) (geo.Geometry, error) { - t, err := g.AsGeomT() - if err != nil { - return geo.Geometry{}, err - } - // Use the square of the tolerance to avoid taking the square root of distance results. - t, err = removeRepeatedPointsFromGeomT(t, tolerance*tolerance) - if err != nil { - return geo.Geometry{}, err - } - return geo.MakeGeometryFromGeomT(t) -} - -func removeRepeatedPointsFromGeomT(t geom.T, tolerance2 float64) (geom.T, error) { - switch t := t.(type) { - case *geom.Point: - case *geom.LineString: - if coords, modified := removeRepeatedCoords(t.Layout(), t.Coords(), tolerance2, 2); modified { - return t.SetCoords(coords) - } - case *geom.Polygon: - if coords, modified := removeRepeatedCoords2(t.Layout(), t.Coords(), tolerance2, 4); modified { - return t.SetCoords(coords) - } - case *geom.MultiPoint: - if coords, modified := removeRepeatedCoords(t.Layout(), t.Coords(), tolerance2, 0); modified { - return t.SetCoords(coords) - } - case *geom.MultiLineString: - if coords, modified := removeRepeatedCoords2(t.Layout(), t.Coords(), tolerance2, 2); modified { - return t.SetCoords(coords) - } - case *geom.MultiPolygon: - if coords, modified := removeRepeatedCoords3(t.Layout(), t.Coords(), tolerance2, 4); modified { - return t.SetCoords(coords) - } - case *geom.GeometryCollection: - for _, g := range t.Geoms() { - if _, err := removeRepeatedPointsFromGeomT(g, tolerance2); err != nil { - return nil, err - } - } - default: - return nil, errors.AssertionFailedf("unknown geometry type: %T", t) - } - return t, nil -} - -func removeRepeatedCoords( - layout geom.Layout, coords []geom.Coord, tolerance2 float64, minCoords int, -) ([]geom.Coord, bool) { - modified := false - switch tolerance2 { - case 0: - for i := 1; i < len(coords) && len(coords) > minCoords; i++ { - if coords[i].Equal(layout, coords[i-1]) { - coords = append(coords[:i], coords[i+1:]...) - modified = true - i-- - } - } - default: - for i := 1; i < len(coords) && len(coords) > minCoords; i++ { - if coordMag2(coordSub(coords[i], coords[i-1])) <= tolerance2 { - coords = append(coords[:i], coords[i+1:]...) - modified = true - i-- - } - } - } - return coords, modified -} - -func removeRepeatedCoords2( - layout geom.Layout, coords2 [][]geom.Coord, tolerance2 float64, minCoords int, -) ([][]geom.Coord, bool) { - modified := false - for i, coords := range coords2 { - if c, m := removeRepeatedCoords(layout, coords, tolerance2, minCoords); m { - coords2[i] = c - modified = true - } - } - return coords2, modified -} - -func removeRepeatedCoords3( - layout geom.Layout, coords3 [][][]geom.Coord, tolerance2 float64, minCoords int, -) ([][][]geom.Coord, bool) { - modified := false - for i, coords2 := range coords3 { - for j, coords := range coords2 { - if c, m := removeRepeatedCoords(layout, coords, tolerance2, minCoords); m { - coords3[i][j] = c - modified = true - } - } - } - return coords3, modified -} diff --git a/postgres/parser/geo/geomfn/reverse.go b/postgres/parser/geo/geomfn/reverse.go deleted file mode 100644 index a2a657949d..0000000000 --- a/postgres/parser/geo/geomfn/reverse.go +++ /dev/null @@ -1,116 +0,0 @@ -// Copyright 2023 Dolthub, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Copyright 2020 The Cockroach Authors. -// -// Use of this software is governed by the Business Source License -// included in the file licenses/BSL.txt. -// -// As of the Change Date specified in that file, in accordance with -// the Business Source License, use of this software will be governed -// by the Apache License, Version 2.0, included in the file -// licenses/APL.txt. - -package geomfn - -import ( - "github.com/twpayne/go-geom" - - "github.com/dolthub/doltgresql/postgres/parser/geo" -) - -// Reverse returns a modified geometry by reversing the order of its vertexes -func Reverse(geometry geo.Geometry) (geo.Geometry, error) { - g, err := geometry.AsGeomT() - if err != nil { - return geo.Geometry{}, err - } - - g, err = reverse(g) - if err != nil { - return geo.Geometry{}, err - } - - return geo.MakeGeometryFromGeomT(g) -} - -func reverse(g geom.T) (geom.T, error) { - if geomCollection, ok := g.(*geom.GeometryCollection); ok { - return reverseCollection(geomCollection) - } - - switch t := g.(type) { - case *geom.Point, *geom.MultiPoint: // cases where reverse does change the order - return g, nil - case *geom.LineString: - g = geom.NewLineStringFlat(t.Layout(), reverseCoords(g.FlatCoords(), g.Stride())).SetSRID(g.SRID()) - case *geom.Polygon: - g = geom.NewPolygonFlat(t.Layout(), reverseCoords(g.FlatCoords(), g.Stride()), t.Ends()).SetSRID(g.SRID()) - case *geom.MultiLineString: - g = geom.NewMultiLineStringFlat(t.Layout(), reverseMulti(g, t.Ends()), t.Ends()).SetSRID(g.SRID()) - case *geom.MultiPolygon: - var ends []int - for _, e := range t.Endss() { - ends = append(ends, e...) - } - g = geom.NewMultiPolygonFlat(t.Layout(), reverseMulti(g, ends), t.Endss()).SetSRID(g.SRID()) - - default: - return nil, geom.ErrUnsupportedType{Value: g} - } - - return g, nil -} - -func reverseCoords(coords []float64, stride int) []float64 { - for i := 0; i < len(coords)/2; i += stride { - for j := 0; j < stride; j++ { - coords[i+j], coords[len(coords)-stride-i+j] = coords[len(coords)-stride-i+j], coords[i+j] - } - } - - return coords -} - -// reverseMulti handles reversing coordinates of MULTI* geometries with nested sub-structures -func reverseMulti(g geom.T, ends []int) []float64 { - coords := g.FlatCoords() - prevEnd := 0 - - for _, end := range ends { - copy( - coords[prevEnd:end], - reverseCoords(coords[prevEnd:end], g.Stride()), - ) - prevEnd = end - } - - return coords -} - -// reverseCollection iterates through a GeometryCollection and calls reverse() on each geometry. -func reverseCollection(geomCollection *geom.GeometryCollection) (*geom.GeometryCollection, error) { - res := geom.NewGeometryCollection() - for _, subG := range geomCollection.Geoms() { - subGeom, err := reverse(subG) - if err != nil { - return nil, err - } - - if err := res.Push(subGeom); err != nil { - return nil, err - } - } - return res, nil -} diff --git a/postgres/parser/geo/geomfn/segmentize.go b/postgres/parser/geo/geomfn/segmentize.go deleted file mode 100644 index 9fe608ce9f..0000000000 --- a/postgres/parser/geo/geomfn/segmentize.go +++ /dev/null @@ -1,97 +0,0 @@ -// Copyright 2023 Dolthub, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Copyright 2020 The Cockroach Authors. -// -// Use of this software is governed by the Business Source License -// included in the file licenses/BSL.txt. -// -// As of the Change Date specified in that file, in accordance with -// the Business Source License, use of this software will be governed -// by the Apache License, Version 2.0, included in the file -// licenses/APL.txt. - -package geomfn - -import ( - "math" - - "github.com/cockroachdb/errors" - "github.com/twpayne/go-geom" - - "github.com/dolthub/doltgresql/postgres/parser/geo" - "github.com/dolthub/doltgresql/postgres/parser/geo/geosegmentize" -) - -// Segmentize return modified Geometry having no segment longer -// that given maximum segment length. -// This works by inserting the extra points in such a manner that -// minimum number of new segments with equal length is created, -// between given two-points such that each segment has length less -// than or equal to given maximum segment length. -func Segmentize(g geo.Geometry, segmentMaxLength float64) (geo.Geometry, error) { - geometry, err := g.AsGeomT() - if err != nil { - return geo.Geometry{}, err - } - switch geometry := geometry.(type) { - case *geom.Point, *geom.MultiPoint: - return g, nil - default: - if segmentMaxLength <= 0 { - return geo.Geometry{}, errors.Newf("maximum segment length must be positive") - } - segGeometry, err := geosegmentize.SegmentizeGeom(geometry, segmentMaxLength, segmentizeCoords) - if err != nil { - return geo.Geometry{}, err - } - return geo.MakeGeometryFromGeomT(segGeometry) - } -} - -// segmentizeCoords inserts multiple points between given two coordinates and -// return resultant point as flat []float64. Points are inserted in such a -// way that they create minimum number segments of equal length such that each -// segment has a length less than or equal to given maximum segment length. -// Note: List of points does not consist of end point. -func segmentizeCoords(a geom.Coord, b geom.Coord, maxSegmentLength float64) ([]float64, error) { - distanceBetweenPoints := math.Sqrt(math.Pow(a.X()-b.X(), 2) + math.Pow(b.Y()-a.Y(), 2)) - - // numberOfSegmentsToCreate represent the total number of segments - // in which given two coordinates will be divided. - numberOfSegmentsToCreate := int(math.Ceil(distanceBetweenPoints / maxSegmentLength)) - numPoints := 2 * (1 + numberOfSegmentsToCreate) - if numPoints > geosegmentize.MaxPoints { - return nil, errors.Newf( - "attempting to segmentize into too many coordinates; need %d points between %v and %v, max %d", - numPoints, - a, - b, - geosegmentize.MaxPoints, - ) - } // segmentFraction represent the fraction of length each segment - // has with respect to total length between two coordinates. - allSegmentizedCoordinates := make([]float64, 0, 2*(1+numberOfSegmentsToCreate)) - allSegmentizedCoordinates = append(allSegmentizedCoordinates, a.Clone()...) - segmentFraction := 1.0 / float64(numberOfSegmentsToCreate) - for pointInserted := 1; pointInserted < numberOfSegmentsToCreate; pointInserted++ { - allSegmentizedCoordinates = append( - allSegmentizedCoordinates, - b.X()*float64(pointInserted)*segmentFraction+a.X()*(1-float64(pointInserted)*segmentFraction), - b.Y()*float64(pointInserted)*segmentFraction+a.Y()*(1-float64(pointInserted)*segmentFraction), - ) - } - - return allSegmentizedCoordinates, nil -} diff --git a/postgres/parser/geo/geomfn/topology_operations.go b/postgres/parser/geo/geomfn/topology_operations.go deleted file mode 100644 index 24280dc89a..0000000000 --- a/postgres/parser/geo/geomfn/topology_operations.go +++ /dev/null @@ -1,139 +0,0 @@ -// Copyright 2023 Dolthub, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Copyright 2020 The Cockroach Authors. -// -// Use of this software is governed by the Business Source License -// included in the file licenses/BSL.txt. -// -// As of the Change Date specified in that file, in accordance with -// the Business Source License, use of this software will be governed -// by the Apache License, Version 2.0, included in the file -// licenses/APL.txt. - -package geomfn - -import ( - "github.com/dolthub/doltgresql/postgres/parser/geo" - "github.com/dolthub/doltgresql/postgres/parser/geo/geos" -) - -// Centroid returns the Centroid of a given Geometry. -func Centroid(g geo.Geometry) (geo.Geometry, error) { - centroidEWKB, err := geos.Centroid(g.EWKB()) - if err != nil { - return geo.Geometry{}, err - } - return geo.ParseGeometryFromEWKB(centroidEWKB) -} - -// ClipByRect clips a given Geometry by the given BoundingBox. -func ClipByRect(g geo.Geometry, b geo.CartesianBoundingBox) (geo.Geometry, error) { - if g.Empty() { - return g, nil - } - clipByRectEWKB, err := geos.ClipByRect(g.EWKB(), b.LoX, b.LoY, b.HiX, b.HiY) - if err != nil { - return geo.Geometry{}, err - } - return geo.ParseGeometryFromEWKB(clipByRectEWKB) -} - -// ConvexHull returns the convex hull of a given Geometry. -func ConvexHull(g geo.Geometry) (geo.Geometry, error) { - convexHullEWKB, err := geos.ConvexHull(g.EWKB()) - if err != nil { - return geo.Geometry{}, err - } - return geo.ParseGeometryFromEWKB(convexHullEWKB) -} - -// Simplify returns a simplified Geometry. -func Simplify(g geo.Geometry, tolerance float64) (geo.Geometry, error) { - simplifiedEWKB, err := geos.Simplify(g.EWKB(), tolerance) - if err != nil { - return geo.Geometry{}, err - } - return geo.ParseGeometryFromEWKB(simplifiedEWKB) -} - -// SimplifyPreserveTopology returns a simplified Geometry with topology preserved. -func SimplifyPreserveTopology(g geo.Geometry, tolerance float64) (geo.Geometry, error) { - simplifiedEWKB, err := geos.TopologyPreserveSimplify(g.EWKB(), tolerance) - if err != nil { - return geo.Geometry{}, err - } - return geo.ParseGeometryFromEWKB(simplifiedEWKB) -} - -// PointOnSurface returns the PointOnSurface of a given Geometry. -func PointOnSurface(g geo.Geometry) (geo.Geometry, error) { - pointOnSurfaceEWKB, err := geos.PointOnSurface(g.EWKB()) - if err != nil { - return geo.Geometry{}, err - } - return geo.ParseGeometryFromEWKB(pointOnSurfaceEWKB) -} - -// Intersection returns the geometries of intersection between A and B. -func Intersection(a geo.Geometry, b geo.Geometry) (geo.Geometry, error) { - if a.SRID() != b.SRID() { - return geo.Geometry{}, geo.NewMismatchingSRIDsError(a.SpatialObject(), b.SpatialObject()) - } - retEWKB, err := geos.Intersection(a.EWKB(), b.EWKB()) - if err != nil { - return geo.Geometry{}, err - } - return geo.ParseGeometryFromEWKB(retEWKB) -} - -// Union returns the geometries of union between A and B. -func Union(a geo.Geometry, b geo.Geometry) (geo.Geometry, error) { - if a.SRID() != b.SRID() { - return geo.Geometry{}, geo.NewMismatchingSRIDsError(a.SpatialObject(), b.SpatialObject()) - } - retEWKB, err := geos.Union(a.EWKB(), b.EWKB()) - if err != nil { - return geo.Geometry{}, err - } - return geo.ParseGeometryFromEWKB(retEWKB) -} - -// SymDifference returns the geometries of symmetric difference between A and B. -func SymDifference(a geo.Geometry, b geo.Geometry) (geo.Geometry, error) { - if a.SRID() != b.SRID() { - return geo.Geometry{}, geo.NewMismatchingSRIDsError(a.SpatialObject(), b.SpatialObject()) - } - retEWKB, err := geos.SymDifference(a.EWKB(), b.EWKB()) - if err != nil { - return geo.Geometry{}, err - } - return geo.ParseGeometryFromEWKB(retEWKB) -} - -// SharedPaths Returns a geometry collection containing paths shared by the two input geometries. -func SharedPaths(a geo.Geometry, b geo.Geometry) (geo.Geometry, error) { - if a.SRID() != b.SRID() { - return geo.Geometry{}, geo.NewMismatchingSRIDsError(a.SpatialObject(), b.SpatialObject()) - } - paths, err := geos.SharedPaths(a.EWKB(), b.EWKB()) - if err != nil { - return geo.Geometry{}, err - } - gm, err := geo.ParseGeometryFromEWKB(paths) - if err != nil { - return geo.Geometry{}, err - } - return gm, nil -} diff --git a/postgres/parser/geo/geomfn/unary_operators.go b/postgres/parser/geo/geomfn/unary_operators.go deleted file mode 100644 index d220470d15..0000000000 --- a/postgres/parser/geo/geomfn/unary_operators.go +++ /dev/null @@ -1,232 +0,0 @@ -// Copyright 2023 Dolthub, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Copyright 2020 The Cockroach Authors. -// -// Use of this software is governed by the Business Source License -// included in the file licenses/BSL.txt. -// -// As of the Change Date specified in that file, in accordance with -// the Business Source License, use of this software will be governed -// by the Apache License, Version 2.0, included in the file -// licenses/APL.txt. - -package geomfn - -import ( - "github.com/cockroachdb/errors" - "github.com/twpayne/go-geom" - "github.com/twpayne/go-geom/encoding/ewkb" - - "github.com/dolthub/doltgresql/postgres/parser/geo" - "github.com/dolthub/doltgresql/postgres/parser/geo/geos" -) - -// Length returns the length of a given Geometry. -// Note only (MULTI)LINESTRING objects have a length. -// (MULTI)POLYGON objects should use Perimeter. -func Length(g geo.Geometry) (float64, error) { - geomRepr, err := g.AsGeomT() - if err != nil { - return 0, err - } - // Fast path. - switch geomRepr.(type) { - case *geom.LineString, *geom.MultiLineString: - return geos.Length(g.EWKB()) - } - return lengthFromGeomT(geomRepr) -} - -// lengthFromGeomT returns the length from a geom.T, recursing down -// GeometryCollections if required. -func lengthFromGeomT(geomRepr geom.T) (float64, error) { - // Length in GEOS will also include polygon "perimeters". - // As such, gate based on on shape underneath. - switch geomRepr := geomRepr.(type) { - case *geom.Point, *geom.MultiPoint, *geom.Polygon, *geom.MultiPolygon: - return 0, nil - case *geom.LineString, *geom.MultiLineString: - ewkb, err := ewkb.Marshal(geomRepr, geo.DefaultEWKBEncodingFormat) - if err != nil { - return 0, err - } - return geos.Length(ewkb) - case *geom.GeometryCollection: - total := float64(0) - for _, subG := range geomRepr.Geoms() { - subLength, err := lengthFromGeomT(subG) - if err != nil { - return 0, err - } - total += subLength - } - return total, nil - default: - return 0, errors.AssertionFailedf("unknown geometry type: %T", geomRepr) - } -} - -// Perimeter returns the perimeter of a given Geometry. -// Note only (MULTI)POLYGON objects have a perimeter. -// (MULTI)LineString objects should use Length. -func Perimeter(g geo.Geometry) (float64, error) { - geomRepr, err := g.AsGeomT() - if err != nil { - return 0, err - } - // Fast path. - switch geomRepr.(type) { - case *geom.Polygon, *geom.MultiPolygon: - return geos.Length(g.EWKB()) - } - return perimeterFromGeomT(geomRepr) -} - -// perimeterFromGeomT returns the perimeter from a geom.T, recursing down -// GeometryCollections if required. -func perimeterFromGeomT(geomRepr geom.T) (float64, error) { - // Length in GEOS will also include polygon "perimeters". - // As such, gate based on on shape underneath. - switch geomRepr := geomRepr.(type) { - case *geom.Point, *geom.MultiPoint, *geom.LineString, *geom.MultiLineString: - return 0, nil - case *geom.Polygon, *geom.MultiPolygon: - ewkb, err := ewkb.Marshal(geomRepr, geo.DefaultEWKBEncodingFormat) - if err != nil { - return 0, err - } - return geos.Length(ewkb) - case *geom.GeometryCollection: - total := float64(0) - for _, subG := range geomRepr.Geoms() { - subLength, err := perimeterFromGeomT(subG) - if err != nil { - return 0, err - } - total += subLength - } - return total, nil - default: - return 0, errors.AssertionFailedf("unknown geometry type: %T", geomRepr) - } -} - -// Area returns the area of a given Geometry. -func Area(g geo.Geometry) (float64, error) { - return geos.Area(g.EWKB()) -} - -// Dimension returns the topological dimension of a given Geometry. -func Dimension(g geo.Geometry) (int, error) { - t, err := g.AsGeomT() - if err != nil { - return 0, err - } - return dimensionFromGeomT(t) -} - -// dimensionFromGeomT returns the dimension from a geom.T, recursing down -// GeometryCollections if required. -func dimensionFromGeomT(geomRepr geom.T) (int, error) { - switch geomRepr := geomRepr.(type) { - case *geom.Point, *geom.MultiPoint: - return 0, nil - case *geom.LineString, *geom.MultiLineString: - return 1, nil - case *geom.Polygon, *geom.MultiPolygon: - return 2, nil - case *geom.GeometryCollection: - maxDim := 0 - for _, g := range geomRepr.Geoms() { - dim, err := dimensionFromGeomT(g) - if err != nil { - return 0, err - } - if dim > maxDim { - maxDim = dim - } - } - return maxDim, nil - default: - return 0, errors.AssertionFailedf("unknown geometry type: %T", geomRepr) - } -} - -// Points returns the points of all coordinates in a geometry as a multipoint. -func Points(g geo.Geometry) (geo.Geometry, error) { - t, err := g.AsGeomT() - if err != nil { - return geo.Geometry{}, err - } - layout := t.Layout() - if gc, ok := t.(*geom.GeometryCollection); ok && gc.Empty() { - layout = geom.XY - } - points := geom.NewMultiPoint(layout).SetSRID(t.SRID()) - iter := geo.NewGeomTIterator(t, geo.EmptyBehaviorOmit) - for { - geomRepr, hasNext, err := iter.Next() - if err != nil { - return geo.Geometry{}, err - } else if !hasNext { - break - } else if geomRepr.Empty() { - continue - } - switch geomRepr := geomRepr.(type) { - case *geom.Point: - if err = pushCoord(points, geomRepr.Coords()); err != nil { - return geo.Geometry{}, err - } - case *geom.LineString: - for i := 0; i < geomRepr.NumCoords(); i++ { - if err = pushCoord(points, geomRepr.Coord(i)); err != nil { - return geo.Geometry{}, err - } - } - case *geom.Polygon: - for i := 0; i < geomRepr.NumLinearRings(); i++ { - linearRing := geomRepr.LinearRing(i) - for j := 0; j < linearRing.NumCoords(); j++ { - if err = pushCoord(points, linearRing.Coord(j)); err != nil { - return geo.Geometry{}, err - } - } - } - default: - return geo.Geometry{}, errors.AssertionFailedf("unexpected type: %T", geomRepr) - } - } - return geo.MakeGeometryFromGeomT(points) -} - -// pushCoord is a helper function for PointsFromGeomT that appends -// a coordinate to a multipoint as a point. -func pushCoord(points *geom.MultiPoint, coord geom.Coord) error { - point, err := geom.NewPoint(points.Layout()).SetCoords(coord) - if err != nil { - return err - } - return points.Push(point) -} - -// Normalize returns the geometry in its normalized form. -func Normalize(g geo.Geometry) (geo.Geometry, error) { - retEWKB, err := geos.Normalize(g.EWKB()) - if err != nil { - return geo.Geometry{}, err - } - return geo.ParseGeometryFromEWKB(retEWKB) -} diff --git a/postgres/parser/geo/geomfn/unary_predicates.go b/postgres/parser/geo/geomfn/unary_predicates.go deleted file mode 100644 index d38b805796..0000000000 --- a/postgres/parser/geo/geomfn/unary_predicates.go +++ /dev/null @@ -1,138 +0,0 @@ -// Copyright 2023 Dolthub, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Copyright 2020 The Cockroach Authors. -// -// Use of this software is governed by the Business Source License -// included in the file licenses/BSL.txt. -// -// As of the Change Date specified in that file, in accordance with -// the Business Source License, use of this software will be governed -// by the Apache License, Version 2.0, included in the file -// licenses/APL.txt. - -package geomfn - -import ( - "github.com/cockroachdb/errors" - "github.com/twpayne/go-geom" - - "github.com/dolthub/doltgresql/postgres/parser/geo" - "github.com/dolthub/doltgresql/postgres/parser/geo/geopb" - "github.com/dolthub/doltgresql/postgres/parser/geo/geos" -) - -// IsClosed returns whether the given geometry has equal start and end points. -// Collections and multi-types must contain all-closed geometries. Empty -// geometries are not closed. -func IsClosed(g geo.Geometry) (bool, error) { - t, err := g.AsGeomT() - if err != nil { - return false, err - } - return isClosedFromGeomT(t) -} - -// isClosedFromGeomT returns whether the given geom.T is closed, recursing -// into collections. -func isClosedFromGeomT(t geom.T) (bool, error) { - if t.Empty() { - return false, nil - } - switch t := t.(type) { - case *geom.Point, *geom.MultiPoint: - return true, nil - case *geom.LinearRing: - return t.Coord(0).Equal(t.Layout(), t.Coord(t.NumCoords()-1)), nil - case *geom.LineString: - return t.Coord(0).Equal(t.Layout(), t.Coord(t.NumCoords()-1)), nil - case *geom.MultiLineString: - for i := 0; i < t.NumLineStrings(); i++ { - if closed, err := isClosedFromGeomT(t.LineString(i)); err != nil || !closed { - return false, err - } - } - return true, nil - case *geom.Polygon: - for i := 0; i < t.NumLinearRings(); i++ { - if closed, err := isClosedFromGeomT(t.LinearRing(i)); err != nil || !closed { - return false, err - } - } - return true, nil - case *geom.MultiPolygon: - for i := 0; i < t.NumPolygons(); i++ { - if closed, err := isClosedFromGeomT(t.Polygon(i)); err != nil || !closed { - return false, err - } - } - return true, nil - case *geom.GeometryCollection: - for _, g := range t.Geoms() { - if closed, err := isClosedFromGeomT(g); err != nil || !closed { - return false, err - } - } - return true, nil - default: - return false, errors.AssertionFailedf("unknown geometry type: %T", t) - } -} - -// IsCollection returns whether the given geometry is of a collection type. -func IsCollection(g geo.Geometry) (bool, error) { - switch g.ShapeType() { - case geopb.ShapeType_MultiPoint, geopb.ShapeType_MultiLineString, geopb.ShapeType_MultiPolygon, - geopb.ShapeType_GeometryCollection: - return true, nil - default: - return false, nil - } -} - -// IsEmpty returns whether the given geometry is empty. -func IsEmpty(g geo.Geometry) (bool, error) { - return g.Empty(), nil -} - -// IsRing returns whether the given geometry is a ring, i.e. that it is a -// simple and closed line. -func IsRing(g geo.Geometry) (bool, error) { - // We explicitly check for empty geometries before checking the type, - // to follow PostGIS behavior where all empty geometries return false. - if g.Empty() { - return false, nil - } - if g.ShapeType() != geopb.ShapeType_LineString { - t, err := g.AsGeomT() - if err != nil { - return false, err - } - e := geom.ErrUnsupportedType{Value: t} - return false, errors.Wrap(e, "should only be called on a linear feature") - } - if closed, err := IsClosed(g); err != nil || !closed { - return false, err - } - if simple, err := IsSimple(g); err != nil || !simple { - return false, err - } - return true, nil -} - -// IsSimple returns whether the given geometry is simple, i.e. that it does not -// intersect or lie tangent to itself. -func IsSimple(g geo.Geometry) (bool, error) { - return geos.IsSimple(g.EWKB()) -} diff --git a/postgres/parser/geo/geomfn/validity_check.go b/postgres/parser/geo/geomfn/validity_check.go deleted file mode 100644 index be6eab50e5..0000000000 --- a/postgres/parser/geo/geomfn/validity_check.go +++ /dev/null @@ -1,88 +0,0 @@ -// Copyright 2023 Dolthub, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Copyright 2020 The Cockroach Authors. -// -// Use of this software is governed by the Business Source License -// included in the file licenses/BSL.txt. -// -// As of the Change Date specified in that file, in accordance with -// the Business Source License, use of this software will be governed -// by the Apache License, Version 2.0, included in the file -// licenses/APL.txt. - -package geomfn - -import ( - "github.com/dolthub/doltgresql/postgres/parser/geo" - "github.com/dolthub/doltgresql/postgres/parser/geo/geos" -) - -// ValidDetail contains information about the validity of a geometry. -type ValidDetail struct { - IsValid bool - // Reason is only populated if IsValid = false. - Reason string - // InvalidLocation is only populated if IsValid = false. - InvalidLocation geo.Geometry -} - -// IsValid returns whether the given Geometry is valid. -func IsValid(g geo.Geometry) (bool, error) { - isValid, err := geos.IsValid(g.EWKB()) - if err != nil { - return false, err - } - return isValid, nil -} - -// IsValidReason returns the reasoning for whether the Geometry is valid or invalid. -func IsValidReason(g geo.Geometry) (string, error) { - reason, err := geos.IsValidReason(g.EWKB()) - if err != nil { - return "", err - } - return reason, nil -} - -// IsValidDetail returns information about the validity of a Geometry. -// It takes in a flag parameter which behaves the same as the GEOS module, where 1 -// means that self-intersecting rings forming holes are considered valid. -func IsValidDetail(g geo.Geometry, flags int) (ValidDetail, error) { - isValid, reason, locEWKB, err := geos.IsValidDetail(g.EWKB(), flags) - if err != nil { - return ValidDetail{}, err - } - var loc geo.Geometry - if len(locEWKB) > 0 { - loc, err = geo.ParseGeometryFromEWKB(locEWKB) - if err != nil { - return ValidDetail{}, err - } - } - return ValidDetail{ - IsValid: isValid, - Reason: reason, - InvalidLocation: loc, - }, nil -} - -// MakeValid returns a valid form of the given Geometry. -func MakeValid(g geo.Geometry) (geo.Geometry, error) { - validEWKB, err := geos.MakeValid(g.EWKB()) - if err != nil { - return geo.Geometry{}, err - } - return geo.ParseGeometryFromEWKB(validEWKB) -} diff --git a/postgres/parser/geo/geosegmentize/geosegmentize.go b/postgres/parser/geo/geosegmentize/geosegmentize.go deleted file mode 100644 index c5ce579b75..0000000000 --- a/postgres/parser/geo/geosegmentize/geosegmentize.go +++ /dev/null @@ -1,140 +0,0 @@ -// Copyright 2023 Dolthub, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Copyright 2020 The Cockroach Authors. -// -// Use of this software is governed by the Business Source License -// included in the file licenses/BSL.txt. -// -// As of the Change Date specified in that file, in accordance with -// the Business Source License, use of this software will be governed -// by the Apache License, Version 2.0, included in the file -// licenses/APL.txt. - -package geosegmentize - -import ( - "github.com/cockroachdb/errors" - "github.com/twpayne/go-geom" -) - -// MaxPoints is the maximum number of points segmentize is allowed to generate. -const MaxPoints = 16336 - -// SegmentizeGeom returns a modified geom.T having no segment longer -// than the given maximum segment length. -// segmentMaxAngleOrLength represents two different things depending -// on the object, which is about to segmentize as in case of geography -// it represents maximum segment angle whereas, in case of geometry it -// represents maximum segment distance. -// segmentizeCoords represents the function's definition which allows -// us to segmentize given two-points. We have to specify segmentizeCoords -// explicitly, as the algorithm for segmentization is significantly -// different for geometry and geography. -func SegmentizeGeom( - geometry geom.T, - segmentMaxAngleOrLength float64, - segmentizeCoords func(geom.Coord, geom.Coord, float64) ([]float64, error), -) (geom.T, error) { - if geometry.Empty() { - return geometry, nil - } - switch geometry := geometry.(type) { - case *geom.Point, *geom.MultiPoint: - return geometry, nil - case *geom.LineString: - var allFlatCoordinates []float64 - for pointIdx := 1; pointIdx < geometry.NumCoords(); pointIdx++ { - coords, err := segmentizeCoords(geometry.Coord(pointIdx-1), geometry.Coord(pointIdx), segmentMaxAngleOrLength) - if err != nil { - return nil, err - } - allFlatCoordinates = append( - allFlatCoordinates, - coords..., - ) - } - // Appending end point as it wasn't included in the iteration of coordinates. - allFlatCoordinates = append(allFlatCoordinates, geometry.Coord(geometry.NumCoords()-1)...) - return geom.NewLineStringFlat(geom.XY, allFlatCoordinates).SetSRID(geometry.SRID()), nil - case *geom.MultiLineString: - segMultiLine := geom.NewMultiLineString(geom.XY).SetSRID(geometry.SRID()) - for lineIdx := 0; lineIdx < geometry.NumLineStrings(); lineIdx++ { - l, err := SegmentizeGeom(geometry.LineString(lineIdx), segmentMaxAngleOrLength, segmentizeCoords) - if err != nil { - return nil, err - } - err = segMultiLine.Push(l.(*geom.LineString)) - if err != nil { - return nil, err - } - } - return segMultiLine, nil - case *geom.LinearRing: - var allFlatCoordinates []float64 - for pointIdx := 1; pointIdx < geometry.NumCoords(); pointIdx++ { - coords, err := segmentizeCoords(geometry.Coord(pointIdx-1), geometry.Coord(pointIdx), segmentMaxAngleOrLength) - if err != nil { - return nil, err - } - allFlatCoordinates = append( - allFlatCoordinates, - coords..., - ) - } - // Appending end point as it wasn't included in the iteration of coordinates. - allFlatCoordinates = append(allFlatCoordinates, geometry.Coord(geometry.NumCoords()-1)...) - return geom.NewLinearRingFlat(geom.XY, allFlatCoordinates).SetSRID(geometry.SRID()), nil - case *geom.Polygon: - segPolygon := geom.NewPolygon(geom.XY).SetSRID(geometry.SRID()) - for loopIdx := 0; loopIdx < geometry.NumLinearRings(); loopIdx++ { - l, err := SegmentizeGeom(geometry.LinearRing(loopIdx), segmentMaxAngleOrLength, segmentizeCoords) - if err != nil { - return nil, err - } - err = segPolygon.Push(l.(*geom.LinearRing)) - if err != nil { - return nil, err - } - } - return segPolygon, nil - case *geom.MultiPolygon: - segMultiPolygon := geom.NewMultiPolygon(geom.XY).SetSRID(geometry.SRID()) - for polygonIdx := 0; polygonIdx < geometry.NumPolygons(); polygonIdx++ { - p, err := SegmentizeGeom(geometry.Polygon(polygonIdx), segmentMaxAngleOrLength, segmentizeCoords) - if err != nil { - return nil, err - } - err = segMultiPolygon.Push(p.(*geom.Polygon)) - if err != nil { - return nil, err - } - } - return segMultiPolygon, nil - case *geom.GeometryCollection: - segGeomCollection := geom.NewGeometryCollection().SetSRID(geometry.SRID()) - for geoIdx := 0; geoIdx < geometry.NumGeoms(); geoIdx++ { - g, err := SegmentizeGeom(geometry.Geom(geoIdx), segmentMaxAngleOrLength, segmentizeCoords) - if err != nil { - return nil, err - } - err = segGeomCollection.Push(g) - if err != nil { - return nil, err - } - } - return segGeomCollection, nil - } - return nil, errors.Newf("unknown type: %T", geometry) -} diff --git a/postgres/parser/interval/btree_based_interval.go b/postgres/parser/interval/btree_based_interval.go deleted file mode 100644 index 27a2fbca8c..0000000000 --- a/postgres/parser/interval/btree_based_interval.go +++ /dev/null @@ -1,1156 +0,0 @@ -// Copyright 2023 Dolthub, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Copyright 2016 The Cockroach Authors. -// -// Use of this software is governed by the Business Source License -// included in the file licenses/BSL.txt. -// -// As of the Change Date specified in that file, in accordance with -// the Business Source License, use of this software will be governed -// by the Apache License, Version 2.0, included in the file -// licenses/APL.txt. -// -// This code is based on: https://github.com/google/btree. - -package interval - -import ( - "sort" - - "github.com/cockroachdb/errors" - - "github.com/dolthub/doltgresql/postgres/parser/syncutil" -) - -const ( - // DefaultBTreeMinimumDegree is the default B-tree minimum degree. Benchmarks - // show that the interval tree performs best with this minimum degree. - DefaultBTreeMinimumDegree = 32 - // DefaultBTreeFreeListSize is the default size of a B-tree's freelist. - DefaultBTreeFreeListSize = 32 -) - -var ( - nilItems = make(items, 16) - nilChildren = make(children, 16) -) - -// FreeList represents a free list of btree nodes. By default each -// BTree has its own FreeList, but multiple BTrees can share the same -// FreeList. -// Two Btrees using the same freelist are safe for concurrent write access. -type FreeList struct { - mu syncutil.Mutex - freelist []*node -} - -// NewFreeList creates a new free list. -// size is the maximum size of the returned free list. -func NewFreeList(size int) *FreeList { - return &FreeList{freelist: make([]*node, 0, size)} -} - -func (f *FreeList) newNode() (n *node) { - f.mu.Lock() - index := len(f.freelist) - 1 - if index < 0 { - f.mu.Unlock() - return new(node) - } - n = f.freelist[index] - f.freelist[index] = nil - f.freelist = f.freelist[:index] - f.mu.Unlock() - return -} - -// freeNode adds the given node to the list, returning true if it was added -// and false if it was discarded. -func (f *FreeList) freeNode(n *node) (out bool) { - f.mu.Lock() - if len(f.freelist) < cap(f.freelist) { - f.freelist = append(f.freelist, n) - out = true - } - f.mu.Unlock() - return -} - -// newBTree creates a new interval tree with the given overlapper function and -// the default B-tree minimum degree. -func newBTree(overlapper Overlapper) *btree { - return newBTreeWithDegree(overlapper, DefaultBTreeMinimumDegree) -} - -// newBTreeWithDegree creates a new interval tree with the given overlapper -// function and the given minimum degree. A minimum degree less than 2 will -// cause a panic. -// -// newBTreeWithDegree(overlapper, 2), for example, will create a 2-3-4 tree (each -// node contains 1-3 Interfaces and 2-4 children). -func newBTreeWithDegree(overlapper Overlapper, minimumDegree int) *btree { - if minimumDegree < 2 { - panic("bad minimum degree") - } - f := NewFreeList(DefaultBTreeFreeListSize) - return &btree{ - minimumDegree: minimumDegree, - overlapper: overlapper, - cow: ©OnWriteContext{freelist: f}, - } -} - -func isValidInterface(a Interface) error { - if a == nil { - // Note: Newf instead of New so that the error message is revealed - // in redact calls. - return errors.Newf("nil interface") - } - r := a.Range() - return rangeError(r) -} - -// interfaces stores Interfaces sorted by Range().End in a node. -type items []Interface - -// insertAt inserts a value into the given index, pushing all subsequent values -// forward. -func (s *items) insertAt(index int, e Interface) { - oldLen := len(*s) - *s = append(*s, nil) - if index < oldLen { - copy((*s)[index+1:], (*s)[index:]) - } - (*s)[index] = e -} - -// removeAt removes a value at a given index, pulling all subsequent values -// back. -func (s *items) removeAt(index int) Interface { - e := (*s)[index] - copy((*s)[index:], (*s)[index+1:]) - (*s)[len(*s)-1] = nil - *s = (*s)[:len(*s)-1] - return e -} - -// pop removes and returns the last element in the list. -func (s *items) pop() (out Interface) { - index := len(*s) - 1 - out = (*s)[index] - (*s)[index] = nil - *s = (*s)[:index] - return -} - -// truncate truncates this instance at index so that it contains only the -// first index items. index must be less than or equal to length. -func (s *items) truncate(index int) { - var toClear items - *s, toClear = (*s)[:index], (*s)[index:] - for len(toClear) > 0 { - toClear = toClear[copy(toClear, nilItems):] - } -} - -// find returns the index where the given Interface should be inserted into this -// list. 'found' is true if the interface already exists in the list at the -// given index. -func (s items) find(e Interface) (index int, found bool) { - i := sort.Search(len(s), func(i int) bool { - return Compare(e, s[i]) < 0 - }) - if i > 0 && Equal(s[i-1], e) { - return i - 1, true - } - return i, false -} - -// children stores child nodes sorted by Range.End in a node. -type children []*node - -// insertAt inserts a value into the given index, pushing all subsequent values -// forward. -func (s *children) insertAt(index int, n *node) { - oldLen := len(*s) - *s = append(*s, nil) - if index < oldLen { - copy((*s)[index+1:], (*s)[index:]) - } - (*s)[index] = n -} - -// removeAt removes a value at a given index, pulling all subsequent values -// back. -func (s *children) removeAt(index int) *node { - n := (*s)[index] - copy((*s)[index:], (*s)[index+1:]) - (*s)[len(*s)-1] = nil - *s = (*s)[:len(*s)-1] - return n -} - -// pop removes and returns the last element in the list. -func (s *children) pop() (out *node) { - index := len(*s) - 1 - out = (*s)[index] - (*s)[index] = nil - *s = (*s)[:index] - return -} - -// truncate truncates this instance at index so that it contains only the -// first index children. index must be less than or equal to length. -func (s *children) truncate(index int) { - var toClear children - *s, toClear = (*s)[:index], (*s)[index:] - for len(toClear) > 0 { - toClear = toClear[copy(toClear, nilChildren):] - } -} - -// node is an internal node in a tree. -// -// It must at all times maintain the invariant that either -// * len(children) == 0, len(interfaces) unconstrained -// * len(children) == len(interfaces) + 1 -type node struct { - // Range is the node range which covers all the ranges in the subtree rooted - // at the node. Range.Start is the leftmost position. Range.End is the - // rightmost position. Here we follow the approach employed by - // https://github.com/biogo/store/tree/master/interval since it make it easy - // to analyze the traversal of intervals which overlaps with a given interval. - // CLRS only uses Range.End. - Range Range - items items - children children - cow *copyOnWriteContext -} - -func (n *node) mutableFor(cow *copyOnWriteContext) *node { - if n.cow == cow { - return n - } - out := cow.newNode() - out.Range = n.Range - if cap(out.items) >= len(n.items) { - out.items = out.items[:len(n.items)] - } else { - out.items = make(items, len(n.items), cap(n.items)) - } - copy(out.items, n.items) - // Copy children - if cap(out.children) >= len(n.children) { - out.children = out.children[:len(n.children)] - } else { - out.children = make(children, len(n.children), cap(n.children)) - } - copy(out.children, n.children) - return out -} - -func (n *node) mutableChild(i int) *node { - c := n.children[i].mutableFor(n.cow) - n.children[i] = c - return c -} - -// split splits the given node at the given index. The current node shrinks, and -// this function returns the Interface that existed at that index and a new node -// containing all interfaces/children after it. Before splitting: -// -// +-----------+ -// | x y z | -// ---/-/-\-\--+ -// -// After splitting: -// -// +-----------+ -// | y | -// -----/-\----+ -// / \ -// v v -// +-----------+ +-----------+ -// | x | | z | -// +-----------+ +-----------+ -// -func (n *node) split(i int, fast bool) (Interface, *node) { - e := n.items[i] - second := n.cow.newNode() - second.items = append(second.items, n.items[i+1:]...) - n.items.truncate(i) - if len(n.children) > 0 { - second.children = append(second.children, n.children[i+1:]...) - n.children.truncate(i + 1) - } - if !fast { - // adjust range for the first split part - oldRangeEnd := n.Range.End - n.Range.End = n.rangeEnd() - - // adjust range for the second split part - second.Range.Start = second.rangeStart() - if n.Range.End.Equal(oldRangeEnd) || e.Range().End.Equal(oldRangeEnd) { - second.Range.End = second.rangeEnd() - } else { - second.Range.End = oldRangeEnd - } - } - return e, second -} - -// maybeSplitChild checks if a child should be split, and if so splits it. -// Returns whether or not a split occurred. -func (n *node) maybeSplitChild(i, maxItems int, fast bool) bool { - if len(n.children[i].items) < maxItems { - return false - } - first := n.mutableChild(i) - e, second := first.split(maxItems/2, fast) - n.items.insertAt(i, e) - n.children.insertAt(i+1, second) - return true -} - -// insert inserts an Interface into the subtree rooted at this node, making sure -// no nodes in the subtree exceed maxItems Interfaces. -func (n *node) insert(e Interface, maxItems int, fast bool) (out Interface, extended bool) { - i, found := n.items.find(e) - if found { - out = n.items[i] - n.items[i] = e - return - } - if len(n.children) == 0 { - n.items.insertAt(i, e) - out = nil - if !fast { - if i == 0 { - extended = true - n.Range.Start = n.items[0].Range().Start - } - if n.items[i].Range().End.Compare(n.Range.End) > 0 { - extended = true - n.Range.End = n.items[i].Range().End - } - } - return - } - if n.maybeSplitChild(i, maxItems, fast) { - inTree := n.items[i] - switch Compare(e, inTree) { - case -1: - // no change, we want first split node - case 1: - i++ // we want second split node - default: - out = n.items[i] - n.items[i] = e - return - } - } - out, extended = n.mutableChild(i).insert(e, maxItems, fast) - if !fast && extended { - extended = false - if i == 0 && n.children[0].Range.Start.Compare(n.Range.Start) < 0 { - extended = true - n.Range.Start = n.children[0].Range.Start - } - if n.children[i].Range.End.Compare(n.Range.End) > 0 { - extended = true - n.Range.End = n.children[i].Range.End - } - } - return -} - -func (t *btree) isEmpty() bool { - return t.root == nil || len(t.root.items) == 0 -} - -func (t *btree) Get(r Range) (o []Interface) { - return t.GetWithOverlapper(r, t.overlapper) -} - -func (t *btree) GetWithOverlapper(r Range, overlapper Overlapper) (o []Interface) { - if err := rangeError(r); err != nil { - return - } - if !t.overlappable(r) { - return - } - t.root.doMatch(func(e Interface) (done bool) { o = append(o, e); return }, r, overlapper) - return -} - -func (t *btree) DoMatching(fn Operation, r Range) bool { - if err := rangeError(r); err != nil { - return false - } - if !t.overlappable(r) { - return false - } - return t.root.doMatch(fn, r, t.overlapper) -} - -func (t *btree) overlappable(r Range) bool { - if t.isEmpty() || !t.overlapper.Overlap(r, t.root.Range) { - return false - } - return true -} - -// benchmarks show that if Comparable.Compare is invoked directly instead of -// through an indirection with Overlapper, Insert, Delete and a traversal to -// visit overlapped intervals have a noticeable speed-up. So two versions of -// doMatch are created. One is for InclusiveOverlapper. The other is for -// ExclusiveOverlapper. -func (n *node) doMatch(fn Operation, r Range, overlapper Overlapper) (done bool) { - if overlapper == InclusiveOverlapper { - return n.inclusiveDoMatch(fn, r, overlapper) - } - return n.exclusiveDoMatch(fn, r, overlapper) -} - -// doMatch for InclusiveOverlapper. -func (n *node) inclusiveDoMatch(fn Operation, r Range, overlapper Overlapper) (done bool) { - length := sort.Search(len(n.items), func(i int) bool { - return n.items[i].Range().Start.Compare(r.End) > 0 - }) - - if len(n.children) == 0 { - for _, e := range n.items[:length] { - if r.Start.Compare(e.Range().End) <= 0 { - if done = fn(e); done { - return - } - } - } - return - } - - for i := 0; i < length; i++ { - c := n.children[i] - if r.Start.Compare(c.Range.End) <= 0 { - if done = c.inclusiveDoMatch(fn, r, overlapper); done { - return - } - } - e := n.items[i] - if r.Start.Compare(e.Range().End) <= 0 { - if done = fn(e); done { - return - } - } - } - - if overlapper.Overlap(r, n.children[length].Range) { - done = n.children[length].inclusiveDoMatch(fn, r, overlapper) - } - return -} - -// doMatch for ExclusiveOverlapper. -func (n *node) exclusiveDoMatch(fn Operation, r Range, overlapper Overlapper) (done bool) { - length := sort.Search(len(n.items), func(i int) bool { - return n.items[i].Range().Start.Compare(r.End) >= 0 - }) - - if len(n.children) == 0 { - for _, e := range n.items[:length] { - if r.Start.Compare(e.Range().End) < 0 { - if done = fn(e); done { - return - } - } - } - return - } - - for i := 0; i < length; i++ { - c := n.children[i] - if r.Start.Compare(c.Range.End) < 0 { - if done = c.exclusiveDoMatch(fn, r, overlapper); done { - return - } - } - e := n.items[i] - if r.Start.Compare(e.Range().End) < 0 { - if done = fn(e); done { - return - } - } - } - - if overlapper.Overlap(r, n.children[length].Range) { - done = n.children[length].exclusiveDoMatch(fn, r, overlapper) - } - return -} - -func (t *btree) Do(fn Operation) bool { - if t.root == nil { - return false - } - return t.root.do(fn) -} - -func (n *node) do(fn Operation) (done bool) { - cLen := len(n.children) - if cLen == 0 { - for _, e := range n.items { - if done = fn(e); done { - return - } - } - return - } - - for i := 0; i < cLen-1; i++ { - c := n.children[i] - if done = c.do(fn); done { - return - } - e := n.items[i] - if done = fn(e); done { - return - } - } - done = n.children[cLen-1].do(fn) - return -} - -// toRemove details what interface to remove in a node.remove call. -type toRemove int - -const ( - removeItem toRemove = iota // removes the given interface - removeMin // removes smallest interface in the subtree - removeMax // removes largest interface in the subtree -) - -// remove removes an interface from the subtree rooted at this node. -func (n *node) remove( - e Interface, minItems int, typ toRemove, fast bool, -) (out Interface, shrunk bool) { - var i int - var found bool - switch typ { - case removeMax: - if len(n.children) == 0 { - return n.removeFromLeaf(len(n.items)-1, fast) - } - i = len(n.items) - case removeMin: - if len(n.children) == 0 { - return n.removeFromLeaf(0, fast) - } - i = 0 - case removeItem: - i, found = n.items.find(e) - if len(n.children) == 0 { - if found { - return n.removeFromLeaf(i, fast) - } - return - } - default: - panic("invalid remove type") - } - // If we get to here, we have children. - if len(n.children[i].items) <= minItems { - out, shrunk = n.growChildAndRemove(i, e, minItems, typ, fast) - return - } - child := n.mutableChild(i) - // Either we had enough interfaces to begin with, or we've done some - // merging/stealing, because we've got enough now and we're ready to return - // stuff. - if found { - // The interface exists at index 'i', and the child we've selected can give - // us a predecessor, since if we've gotten here it's got > minItems - // interfaces in it. - out = n.items[i] - // We use our special-case 'remove' call with typ=removeMax to pull the - // predecessor of interface i (the rightmost leaf of our immediate left - // child) and set it into where we pulled the interface from. - n.items[i], _ = child.remove(nil, minItems, removeMax, fast) - if !fast { - shrunk = n.adjustRangeEndForRemoval(out, nil) - } - return - } - // Final recursive call. Once we're here, we know that the interface isn't in - // this node and that the child is big enough to remove from. - out, shrunk = child.remove(e, minItems, typ, fast) - if !fast && shrunk { - shrunkOnStart := false - if i == 0 { - if n.Range.Start.Compare(child.Range.Start) < 0 { - shrunkOnStart = true - n.Range.Start = child.Range.Start - } - } - shrunkOnEnd := n.adjustRangeEndForRemoval(out, nil) - shrunk = shrunkOnStart || shrunkOnEnd - } - return -} - -// adjustRangeEndForRemoval adjusts Range.End for the node after an interface -// and/or a child is removed. -func (n *node) adjustRangeEndForRemoval(e Interface, c *node) (decreased bool) { - if (e != nil && e.Range().End.Equal(n.Range.End)) || (c != nil && c.Range.End.Equal(n.Range.End)) { - newEnd := n.rangeEnd() - if n.Range.End.Compare(newEnd) > 0 { - decreased = true - n.Range.End = newEnd - } - } - return -} - -// removeFromLeaf removes children[i] from the leaf node. -func (n *node) removeFromLeaf(i int, fast bool) (out Interface, shrunk bool) { - if i == len(n.items)-1 { - out = n.items.pop() - } else { - out = n.items.removeAt(i) - } - if !fast && len(n.items) > 0 { - shrunkOnStart := false - if i == 0 { - oldStart := n.Range.Start - n.Range.Start = n.items[0].Range().Start - if !n.Range.Start.Equal(oldStart) { - shrunkOnStart = true - } - } - shrunkOnEnd := n.adjustRangeEndForRemoval(out, nil) - shrunk = shrunkOnStart || shrunkOnEnd - } - return -} - -// growChildAndRemove grows child 'i' to make sure it's possible to remove an -// Interface from it while keeping it at minItems, then calls remove to -// actually remove it. -// -// Most documentation says we have to do two sets of special casing: -// 1) interface is in this node -// 2) interface is in child -// In both cases, we need to handle the two subcases: -// A) node has enough values that it can spare one -// B) node doesn't have enough values -// For the latter, we have to check: -// a) left sibling has node to spare -// b) right sibling has node to spare -// c) we must merge -// To simplify our code here, we handle cases #1 and #2 the same: -// If a node doesn't have enough Interfaces, we make sure it does (using a,b,c). -// We then simply redo our remove call, and the second time (regardless of -// whether we're in case 1 or 2), we'll have enough Interfaces and can guarantee -// that we hit case A. -func (n *node) growChildAndRemove( - i int, e Interface, minItems int, typ toRemove, fast bool, -) (out Interface, shrunk bool) { - if i > 0 && len(n.children[i-1].items) > minItems { - n.stealFromLeftChild(i, fast) - } else if i < len(n.items) && len(n.children[i+1].items) > minItems { - n.stealFromRightChild(i, fast) - } else { - if i >= len(n.items) { - i-- - } - n.mergeWithRightChild(i, fast) - } - return n.remove(e, minItems, typ, fast) -} - -// Steal from left child. Before stealing: -// -// +-----------+ -// | y | -// -----/-\----+ -// / \ -// v v -// +-----------+ +-----------+ -// | x | | | -// +----------\+ +-----------+ -// \ -// v -// a -// -// After stealing: -// -// +-----------+ -// | x | -// -----/-\----+ -// / \ -// v v -// +-----------+ +-----------+ -// | | | y | -// +-----------+ +/----------+ -// / -// v -// a -// -func (n *node) stealFromLeftChild(i int, fast bool) { - // steal - stealTo := n.mutableChild(i) - stealFrom := n.mutableChild(i - 1) - x := stealFrom.items.pop() - y := n.items[i-1] - stealTo.items.insertAt(0, y) - n.items[i-1] = x - var a *node - if len(stealFrom.children) > 0 { - a = stealFrom.children.pop() - stealTo.children.insertAt(0, a) - } - - if !fast { - // adjust range for stealFrom - stealFrom.adjustRangeEndForRemoval(x, a) - - // adjust range for stealTo - stealTo.Range.Start = stealTo.rangeStart() - if y.Range().End.Compare(stealTo.Range.End) > 0 { - stealTo.Range.End = y.Range().End - } - if a != nil && a.Range.End.Compare(stealTo.Range.End) > 0 { - stealTo.Range.End = a.Range.End - } - } -} - -// Steal from right child. Before stealing: -// -// +-----------+ -// | y | -// -----/-\----+ -// / \ -// v v -// +-----------+ +-----------+ -// | | | x | -// +---------- + +/----------+ -// / -// v -// a -// -// After stealing: -// -// +-----------+ -// | x | -// -----/-\----+ -// / \ -// v v -// +-----------+ +-----------+ -// | y | | | -// +----------\+ +-----------+ -// \ -// v -// a -// -func (n *node) stealFromRightChild(i int, fast bool) { - // steal - stealTo := n.mutableChild(i) - stealFrom := n.mutableChild(i + 1) - x := stealFrom.items.removeAt(0) - y := n.items[i] - stealTo.items = append(stealTo.items, y) - n.items[i] = x - var a *node - if len(stealFrom.children) > 0 { - a = stealFrom.children.removeAt(0) - stealTo.children = append(stealTo.children, a) - } - - if !fast { - // adjust range for stealFrom - stealFrom.Range.Start = stealFrom.rangeStart() - stealFrom.adjustRangeEndForRemoval(x, a) - - // adjust range for stealTo - if y.Range().End.Compare(stealTo.Range.End) > 0 { - stealTo.Range.End = y.Range().End - } - if a != nil && a.Range.End.Compare(stealTo.Range.End) > 0 { - stealTo.Range.End = a.Range.End - } - } -} - -// Merge with right child. Before merging: -// -// +-----------+ -// | u y v | -// -----/-\----+ -// / \ -// v v -// +-----------+ +-----------+ -// | x | | z | -// +---------- + +-----------+ -// -// After merging: -// -// +-----------+ -// | u v | -// ------|-----+ -// | -// v -// +-----------+ -// | x y z | -// +---------- + -// -func (n *node) mergeWithRightChild(i int, fast bool) { - // merge - child := n.mutableChild(i) - mergeItem := n.items.removeAt(i) - mergeChild := n.children.removeAt(i + 1) - child.items = append(child.items, mergeItem) - child.items = append(child.items, mergeChild.items...) - child.children = append(child.children, mergeChild.children...) - - if !fast { - if mergeItem.Range().End.Compare(child.Range.End) > 0 { - child.Range.End = mergeItem.Range().End - } - if mergeChild.Range.End.Compare(child.Range.End) > 0 { - child.Range.End = mergeChild.Range.End - } - } - n.cow.freeNode(mergeChild) -} - -var _ Tree = (*btree)(nil) - -// btree is an interval tree based on an augmented BTree. -// -// Tree stores Instances in an ordered structure, allowing easy insertion, -// removal, and iteration. -// -// Write operations are not safe for concurrent mutation by multiple -// goroutines, but Read operations are. -type btree struct { - length int - minimumDegree int - overlapper Overlapper - root *node - cow *copyOnWriteContext -} - -// copyOnWriteContext pointers determine node ownership... a tree with a write -// context equivalent to a node's write context is allowed to modify that node. -// A tree whose write context does not match a node's is not allowed to modify -// it, and must create a new, writable copy (IE: it's a Clone). -// -// When doing any write operation, we maintain the invariant that the current -// node's context is equal to the context of the tree that requested the write. -// We do this by, before we descend into any node, creating a copy with the -// correct context if the contexts don't match. -// -// Since the node we're currently visiting on any write has the requesting -// tree's context, that node is modifiable in place. Children of that node may -// not share context, but before we descend into them, we'll make a mutable -// copy. -type copyOnWriteContext struct { - freelist *FreeList -} - -// cloneInternal clones the btree, lazily. Clone should not be called concurrently, -// but the original tree (t) and the new tree (t2) can be used concurrently -// once the Clone call completes. -// -// The internal tree structure of b is marked read-only and shared between t and -// t2. Writes to both t and t2 use copy-on-write logic, creating new nodes -// whenever one of b's original nodes would have been modified. Read operations -// should have no performance degredation. Write operations for both t and t2 -// will initially experience minor slow-downs caused by additional allocs and -// copies due to the aforementioned copy-on-write logic, but should converge to -// the original performance characteristics of the original tree. -func (t *btree) cloneInternal() (t2 *btree) { - // Create two entirely new copy-on-write contexts. - // This operation effectively creates three trees: - // the original, shared nodes (old b.cow) - // the new b.cow nodes - // the new out.cow nodes - cow1, cow2 := *t.cow, *t.cow - out := *t - t.cow = &cow1 - out.cow = &cow2 - return &out -} - -// Clone clones the btree, lazily. -func (t *btree) Clone() Tree { - return t.cloneInternal() -} - -// adjustRange sets the Range to the maximum extent of the childrens' Range -// spans and its range spans. -func (n *node) adjustRange() { - n.Range.Start = n.rangeStart() - n.Range.End = n.rangeEnd() -} - -// rangeStart returns the leftmost position for the node range, assuming that -// its children have correct range extents. -func (n *node) rangeStart() Comparable { - minStart := n.items[0].Range().Start - if len(n.children) > 0 { - minStart = n.children[0].Range.Start - } - return minStart -} - -// rangeEnd returns the rightmost position for the node range, assuming that its -// children have correct range extents. -func (n *node) rangeEnd() Comparable { - if len(n.items) == 0 { - maxEnd := n.children[0].Range.End - for _, c := range n.children[1:] { - if end := c.Range.End; maxEnd.Compare(end) < 0 { - maxEnd = end - } - } - return maxEnd - } - maxEnd := n.items[0].Range().End - for _, e := range n.items[1:] { - if end := e.Range().End; maxEnd.Compare(end) < 0 { - maxEnd = end - } - } - for _, c := range n.children { - if end := c.Range.End; maxEnd.Compare(end) < 0 { - maxEnd = end - } - } - return maxEnd -} - -func (t *btree) AdjustRanges() { - if t.isEmpty() { - return - } - t.root.adjustRanges(t.root.cow) -} - -func (n *node) adjustRanges(c *copyOnWriteContext) { - if n.cow != c { - // Could not have been modified. - return - } - for _, child := range n.children { - child.adjustRanges(c) - } - n.adjustRange() -} - -// maxItems returns the max number of Interfaces to allow per node. -func (t *btree) maxItems() int { - return t.minimumDegree*2 - 1 -} - -// minItems returns the min number of Interfaces to allow per node (ignored -// for the root node). -func (t *btree) minItems() int { - return t.minimumDegree - 1 -} - -func (c *copyOnWriteContext) newNode() (n *node) { - n = c.freelist.newNode() - n.cow = c - return -} - -type freeType int - -const ( - ftFreelistFull freeType = iota // node was freed (available for GC, not stored in freelist) - ftStored // node was stored in the freelist for later use - ftNotOwned // node was ignored by COW, since it's owned by another one -) - -// freeNode frees a node within a given COW context, if it's owned by that -// context. It returns what happened to the node (see freeType const -// documentation). -func (c *copyOnWriteContext) freeNode(n *node) freeType { - if n.cow == c { - // clear to allow GC - n.items.truncate(0) - n.children.truncate(0) - n.cow = nil // clear to allow GC - if c.freelist.freeNode(n) { - return ftStored - } - return ftFreelistFull - } - return ftNotOwned -} - -func (t *btree) Insert(e Interface, fast bool) (err error) { - // t.metrics("Insert") - if err = isValidInterface(e); err != nil { - return - } - - if t.root == nil { - t.root = t.cow.newNode() - t.root.items = append(t.root.items, e) - t.length++ - if !fast { - t.root.Range.Start = e.Range().Start - t.root.Range.End = e.Range().End - } - return nil - } - - t.root = t.root.mutableFor(t.cow) - if len(t.root.items) >= t.maxItems() { - oldroot := t.root - t.root = t.cow.newNode() - if !fast { - t.root.Range.Start = oldroot.Range.Start - t.root.Range.End = oldroot.Range.End - } - e2, second := oldroot.split(t.maxItems()/2, fast) - t.root.items = append(t.root.items, e2) - t.root.children = append(t.root.children, oldroot, second) - } - - out, _ := t.root.insert(e, t.maxItems(), fast) - if out == nil { - t.length++ - } - return -} - -func (t *btree) Delete(e Interface, fast bool) (err error) { - // t.metrics("Delete") - if err = isValidInterface(e); err != nil { - return - } - if !t.overlappable(e.Range()) { - return - } - t.delete(e, removeItem, fast) - return -} - -func (t *btree) delete(e Interface, typ toRemove, fast bool) Interface { - t.root = t.root.mutableFor(t.cow) - out, _ := t.root.remove(e, t.minItems(), typ, fast) - if len(t.root.items) == 0 && len(t.root.children) > 0 { - oldroot := t.root - t.root = t.root.children[0] - t.cow.freeNode(oldroot) - } - if out != nil { - t.length-- - } - return out -} - -func (t *btree) Len() int { - return t.length -} - -type stackElem struct { - node *node - index int -} - -type btreeIterator struct { - stack []*stackElem -} - -func (ti *btreeIterator) Next() (i Interface, ok bool) { - if len(ti.stack) == 0 { - return nil, false - } - elem := ti.stack[len(ti.stack)-1] - curItem := elem.node.items[elem.index] - elem.index++ - if elem.index >= len(elem.node.items) { - ti.stack = ti.stack[:len(ti.stack)-1] - } - if len(elem.node.children) > 0 { - for r := elem.node.children[elem.index]; r != nil; r = r.children[0] { - ti.stack = append(ti.stack, &stackElem{r, 0}) - if len(r.children) == 0 { - break - } - } - } - return curItem, true -} - -func (t *btree) Iterator() TreeIterator { - var ti btreeIterator - for n := t.root; n != nil; n = n.children[0] { - ti.stack = append(ti.stack, &stackElem{n, 0}) - if len(n.children) == 0 { - break - } - } - return &ti -} - -// ClearWithOpt removes all items from the btree. If addNodesToFreelist is -// true, t's nodes are added to its freelist as part of this call, until the -// freelist is full. Otherwise, the root node is simply dereferenced and the -// subtree left to Go's normal GC processes. -// -// This can be much faster than calling Delete on all elements, because that -// requires finding/removing each element in the tree and updating the tree -// accordingly. It also is somewhat faster than creating a new tree to replace -// the old one, because nodes from the old tree are reclaimed into the freelist -// for use by the new one, instead of being lost to the garbage collector. -// -// This call takes: -// O(1): when addNodesToFreelist is false, this is a single operation. -// O(1): when the freelist is already full, it breaks out immediately -// O(freelist size): when the freelist is empty and the nodes are all owned -// by this tree, nodes are added to the freelist until full. -// O(tree size): when all nodes are owned by another tree, all nodes are -// iterated over looking for nodes to add to the freelist, and due to -// ownership, none are. -func (t *btree) ClearWithOpt(addNodesToFreelist bool) { - if t.root != nil && addNodesToFreelist { - t.root.reset(t.cow) - } - t.root, t.length = nil, 0 -} - -func (t *btree) Clear() { - t.ClearWithOpt(true) -} - -// reset returns a subtree to the freelist. It breaks out immediately if the -// freelist is full, since the only benefit of iterating is to fill that -// freelist up. Returns true if parent reset call should continue. -func (n *node) reset(c *copyOnWriteContext) bool { - if n.cow != c { - return false - } - for _, child := range n.children { - if !child.reset(c) { - return false - } - } - return c.freeNode(n) != ftFreelistFull -} diff --git a/postgres/parser/interval/bu23.go b/postgres/parser/interval/bu23.go deleted file mode 100644 index c4374fa917..0000000000 --- a/postgres/parser/interval/bu23.go +++ /dev/null @@ -1,37 +0,0 @@ -// Copyright 2023 Dolthub, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Copyright 2019 The Cockroach Authors. -// -// Use of this software is governed by the Business Source License -// included in the file licenses/BSL.txt. -// -// As of the Change Date specified in that file, in accordance with -// the Business Source License, use of this software will be governed -// by the Apache License, Version 2.0, included in the file -// licenses/APL.txt. - -// Copyright ©2014 The bíogo Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in licenses/BSD-biogo.txt. - -// This code originated in the github.com/biogo/store/interval package. - -//go:build !td234 -// +build !td234 - -package interval - -// LLRBMode . -const LLRBMode = BU23 diff --git a/postgres/parser/interval/interval.go b/postgres/parser/interval/interval.go deleted file mode 100644 index c8fa0c7e36..0000000000 --- a/postgres/parser/interval/interval.go +++ /dev/null @@ -1,251 +0,0 @@ -// Copyright 2023 Dolthub, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Copyright ©2012 The bíogo Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// Portions of this file are additionally subject to the following -// license and copyright. -// -// Copyright 2016 The Cockroach Authors. -// -// Use of this software is governed by the Business Source License -// included in the file licenses/BSL.txt. -// -// As of the Change Date specified in that file, in accordance with -// the Business Source License, use of this software will be governed -// by the Apache License, Version 2.0, included in the file -// licenses/APL.txt. - -// This code originated in the github.com/biogo/store/interval package. - -// Package interval provides two implementations for an interval tree. One is -// based on an augmented Left-Leaning Red Black tree. The other is based on an -// augmented BTree. -package interval - -import ( - "bytes" - "fmt" - - "github.com/cockroachdb/errors" -) - -// ErrInvertedRange is returned if an interval is used where the start value is greater -// than the end value. -var ErrInvertedRange = errors.Newf("interval: inverted range") - -// ErrEmptyRange is returned if an interval is used where the start value is equal -// to the end value. -var ErrEmptyRange = errors.Newf("interval: empty range") - -// ErrNilRange is returned if an interval is used where both the start value and -// the end value are nil. This is a specialization of ErrEmptyRange. -var ErrNilRange = errors.Newf("interval: nil range") - -func rangeError(r Range) error { - switch r.Start.Compare(r.End) { - case 1: - return ErrInvertedRange - case 0: - if len(r.Start) == 0 && len(r.End) == 0 { - return ErrNilRange - } - return ErrEmptyRange - default: - return nil - } -} - -// A Range is a type that describes the basic characteristics of an interval. -type Range struct { - Start, End Comparable -} - -// Equal returns whether the two ranges are equal. -func (r Range) Equal(other Range) bool { - return r.Start.Equal(other.Start) && r.End.Equal(other.End) -} - -// String implements the Stringer interface. -func (r Range) String() string { - return fmt.Sprintf("{%x-%x}", r.Start, r.End) -} - -// Overlapper specifies the overlapping relationship. -type Overlapper interface { - // Overlap checks whether two ranges overlap. - Overlap(Range, Range) bool -} - -type inclusiveOverlapper struct{} - -// Overlap checks where a and b overlap in the inclusive way. -func (overlapper inclusiveOverlapper) Overlap(a Range, b Range) bool { - return overlapsInclusive(a, b) -} - -func overlapsInclusive(a Range, b Range) bool { - return a.Start.Compare(b.End) <= 0 && b.Start.Compare(a.End) <= 0 -} - -// InclusiveOverlapper defines overlapping as a pair of ranges that share a segment of the keyspace -// in the inclusive way. "inclusive" means that both start and end keys treated as inclusive values. -var InclusiveOverlapper = inclusiveOverlapper{} - -type exclusiveOverlapper struct{} - -// Overlap checks where a and b overlap in the exclusive way. -func (overlapper exclusiveOverlapper) Overlap(a Range, b Range) bool { - return overlapsExclusive(a, b) -} - -func overlapsExclusive(a Range, b Range) bool { - return a.Start.Compare(b.End) < 0 && b.Start.Compare(a.End) < 0 -} - -// ExclusiveOverlapper defines overlapping as a pair of ranges that share a segment of the keyspace -// in the exclusive. "exclusive" means that the start keys are treated as inclusive and the end keys -// are treated as exclusive. -var ExclusiveOverlapper = exclusiveOverlapper{} - -// An Interface is a type that can be inserted into an interval tree. -type Interface interface { - Range() Range - // Returns a unique ID for the element. - // TODO(nvanbenschoten): Should this be changed to an int64? - ID() uintptr -} - -// Compare returns a value indicating the sort order relationship between a and b. The comparison is -// performed lexicographically on (a.Range().Start, a.ID()) and (b.Range().Start, b.ID()) tuples -// where Range().Start is more significant that ID(). -// -// Given c = Compare(a, b): -// -// c == -1 if (a.Range().Start, a.ID()) < (b.Range().Start, b.ID()); -// c == 0 if (a.Range().Start, a.ID()) == (b.Range().Start, b.ID()); and -// c == 1 if (a.Range().Start, a.ID()) > (b.Range().Start, b.ID()). -// -// "c == 0" is equivalent to "Equal(a, b) == true". -func Compare(a, b Interface) int { - startCmp := a.Range().Start.Compare(b.Range().Start) - if startCmp != 0 { - return startCmp - } - aID := a.ID() - bID := b.ID() - if aID < bID { - return -1 - } else if aID > bID { - return 1 - } else { - return 0 - } -} - -// Equal returns a boolean indicating whether the given Interfaces are equal to each other. If -// "Equal(a, b) == true", "a.Range().End == b.Range().End" must hold. Otherwise, the interval tree -// behavior is undefined. "Equal(a, b) == true" is equivalent to "Compare(a, b) == 0". But the -// former has measurably better performance than the latter. So Equal should be used when only -// equality state is needed. -func Equal(a, b Interface) bool { - return a.ID() == b.ID() && a.Range().Start.Equal(b.Range().Start) -} - -// A Comparable is a type that describes the ends of a Range. -type Comparable []byte - -// Compare returns a value indicating the sort order relationship between the -// receiver and the parameter. -// -// Given c = a.Compare(b): -// c == -1 if a < b; -// c == 0 if a == b; and -// c == 1 if a > b. -// -func (c Comparable) Compare(o Comparable) int { - return bytes.Compare(c, o) -} - -// Equal returns a boolean indicating if the given comparables are equal to -// each other. Note that this has measurably better performance than -// Compare() == 0, so it should be used when only equality state is needed. -func (c Comparable) Equal(o Comparable) bool { - return bytes.Equal(c, o) -} - -// An Operation is a function that operates on an Interface. If done is returned true, the -// Operation is indicating that no further work needs to be done and so the DoMatching function -// should traverse no further. -type Operation func(Interface) (done bool) - -// Tree is an interval tree. For all the functions which have a fast argument, -// fast being true means a fast operation which does not adjust node ranges. If -// fast is false, node ranges are adjusted. -type Tree interface { - // AdjustRanges fixes range fields for all nodes in the tree. This must be - // called before Get, Do or DoMatching* is used if fast insertion or deletion - // has been performed. - AdjustRanges() - // Len returns the number of intervals stored in the Tree. - Len() int - // Get returns a slice of Interfaces that overlap r in the tree. The slice is - // sorted nondecreasingly by interval start. - Get(r Range) []Interface - // GetWithOverlapper returns a slice of Interfaces that overlap r in the tree - // using the provided overlapper function. The slice is sorted nondecreasingly - // by interval start. - GetWithOverlapper(r Range, overlapper Overlapper) []Interface - // Insert inserts the Interface e into the tree. Insertions may replace an - // existing Interface which is equal to the Interface e. - Insert(e Interface, fast bool) error - // Delete deletes the Interface e if it exists in the tree. The deleted - // Interface is equal to the Interface e. - Delete(e Interface, fast bool) error - // Do performs fn on all intervals stored in the tree. The traversal is done - // in the nondecreasing order of interval start. A boolean is returned - // indicating whether the traversal was interrupted by an Operation returning - // true. If fn alters stored intervals' sort relationships, future tree - // operation behaviors are undefined. - Do(fn Operation) bool - // DoMatching performs fn on all intervals stored in the tree that overlaps r. - // The traversal is done in the nondecreasing order of interval start. A - // boolean is returned indicating whether the traversal was interrupted by an - // Operation returning true. If fn alters stored intervals' sort - // relationships, future tree operation behaviors are undefined. - DoMatching(fn Operation, r Range) bool - // Iterator creates an iterator to iterate over all intervals stored in the - // tree, in-order. - Iterator() TreeIterator - // Clear this tree. - Clear() - // Clone clones the tree, returning a copy. - Clone() Tree -} - -// TreeIterator iterates over all intervals stored in the interval tree, in-order. -type TreeIterator interface { - // Next returns the current interval stored in the interval tree and moves - // the iterator to the next interval. The method returns false if no intervals - // remain in the interval tree. - Next() (Interface, bool) -} - -// NewTree creates a new interval tree with the given overlapper function. It -// uses the augmented Left-Leaning Red Black tree implementation. -func NewTree(overlapper Overlapper) Tree { - return newLLRBTree(overlapper) -} diff --git a/postgres/parser/interval/llrb_based_interval.go b/postgres/parser/interval/llrb_based_interval.go deleted file mode 100644 index 9c38e52d9a..0000000000 --- a/postgres/parser/interval/llrb_based_interval.go +++ /dev/null @@ -1,690 +0,0 @@ -// Copyright 2023 Dolthub, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Copyright ©2012 The bíogo Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in licenses/BSD-biogo.txt. - -// Portions of this file are additionally subject to the following -// license and copyright. -// -// Copyright 2016 The Cockroach Authors. -// -// Use of this software is governed by the Business Source License -// included in the file licenses/BSL.txt. -// -// As of the Change Date specified in that file, in accordance with -// the Business Source License, use of this software will be governed -// by the Apache License, Version 2.0, included in the file -// licenses/APL.txt. - -// This code originated in the github.com/biogo/store/interval package. - -package interval - -import "github.com/biogo/store/llrb" - -// Operation LLRBMode of the underlying LLRB tree. -const ( - TD234 = iota - BU23 -) - -func init() { - if LLRBMode != TD234 && LLRBMode != BU23 { - panic("interval: unknown LLRBMode") - } -} - -// A Node represents a node in a tree. -type llrbNode struct { - Elem Interface - Range Range - Left, Right *llrbNode - Color llrb.Color -} - -var _ Tree = (*llrbTree)(nil) - -// llrbTree an interval tree based on an augmented Left-Leaning Red Black tree. -type llrbTree struct { - Root *llrbNode // root node of the tree. - Count int // number of elements stored. - Overlapper Overlapper -} - -// newLLRBTree creates a new interval tree with the given overlapper function. -func newLLRBTree(overlapper Overlapper) *llrbTree { - return &llrbTree{Overlapper: overlapper} -} - -// Helper methods - -// color returns the effect color of a Node. A nil node returns black. -func (n *llrbNode) color() llrb.Color { - if n == nil { - return llrb.Black - } - return n.Color -} - -// maxRange returns the furthest right position held by the subtree -// rooted at root, assuming that the left and right nodes have correct -// range extents. -func maxRange(root, left, right *llrbNode) Comparable { - end := root.Elem.Range().End - if left != nil && left.Range.End.Compare(end) > 0 { - end = left.Range.End - } - if right != nil && right.Range.End.Compare(end) > 0 { - end = right.Range.End - } - return end -} - -// (a,c)b -rotL-> ((a,)b,)c -func (n *llrbNode) rotateLeft() (root *llrbNode) { - // Assumes: n has a right child. - root = n.Right - n.Right = root.Left - root.Left = n - root.Color = n.Color - n.Color = llrb.Red - - root.Left.Range.End = maxRange(root.Left, root.Left.Left, root.Left.Right) - if root.Left == nil { - root.Range.Start = root.Elem.Range().Start - } else { - root.Range.Start = root.Left.Range.Start - } - root.Range.End = maxRange(root, root.Left, root.Right) - - return -} - -// (a,c)b -rotR-> (,(,c)b)a -func (n *llrbNode) rotateRight() (root *llrbNode) { - // Assumes: n has a left child. - root = n.Left - n.Left = root.Right - root.Right = n - root.Color = n.Color - n.Color = llrb.Red - - if root.Right.Left == nil { - root.Right.Range.Start = root.Right.Elem.Range().Start - } else { - root.Right.Range.Start = root.Right.Left.Range.Start - } - root.Right.Range.End = maxRange(root.Right, root.Right.Left, root.Right.Right) - root.Range.End = maxRange(root, root.Left, root.Right) - - return -} - -// (aR,cR)bB -flipC-> (aB,cB)bR | (aB,cB)bR -flipC-> (aR,cR)bB -func (n *llrbNode) flipColors() { - // Assumes: n has two children. - n.Color = !n.Color - n.Left.Color = !n.Left.Color - n.Right.Color = !n.Right.Color -} - -// fixUp ensures that black link balance is correct, that red nodes lean left, -// and that 4 nodes are split in the case of BU23 and properly balanced in TD234. -func (n *llrbNode) fixUp(fast bool) *llrbNode { - if !fast { - n.adjustRange() - } - if n.Right.color() == llrb.Red { - if LLRBMode == TD234 && n.Right.Left.color() == llrb.Red { - n.Right = n.Right.rotateRight() - } - n = n.rotateLeft() - } - if n.Left.color() == llrb.Red && n.Left.Left.color() == llrb.Red { - n = n.rotateRight() - } - if LLRBMode == BU23 && n.Left.color() == llrb.Red && n.Right.color() == llrb.Red { - n.flipColors() - } - - return n -} - -// adjustRange sets the Range to the maximum extent of the children's Range -// spans and the node's Elem span. -func (n *llrbNode) adjustRange() { - if n.Left == nil { - n.Range.Start = n.Elem.Range().Start - } else { - n.Range.Start = n.Left.Range.Start - } - n.Range.End = maxRange(n, n.Left, n.Right) -} - -func (n *llrbNode) moveRedLeft() *llrbNode { - n.flipColors() - if n.Right.Left.color() == llrb.Red { - n.Right = n.Right.rotateRight() - n = n.rotateLeft() - n.flipColors() - if LLRBMode == TD234 && n.Right.Right.color() == llrb.Red { - n.Right = n.Right.rotateLeft() - } - } - return n -} - -func (n *llrbNode) moveRedRight() *llrbNode { - n.flipColors() - if n.Left.Left.color() == llrb.Red { - n = n.rotateRight() - n.flipColors() - } - return n -} - -func (t *llrbTree) Len() int { - return t.Count -} - -func (t *llrbTree) Get(r Range) (o []Interface) { - return t.GetWithOverlapper(r, t.Overlapper) -} - -func (t *llrbTree) GetWithOverlapper(r Range, overlapper Overlapper) (o []Interface) { - if t.Root != nil && overlapper.Overlap(r, t.Root.Range) { - t.Root.doMatch(func(e Interface) (done bool) { o = append(o, e); return }, r, overlapper.Overlap) - } - return -} - -func (t *llrbTree) AdjustRanges() { - if t.Root == nil { - return - } - t.Root.adjustRanges() -} - -func (n *llrbNode) adjustRanges() { - if n.Left != nil { - n.Left.adjustRanges() - } - if n.Right != nil { - n.Right.adjustRanges() - } - n.adjustRange() -} - -func (t *llrbTree) Insert(e Interface, fast bool) (err error) { - r := e.Range() - if err := rangeError(r); err != nil { - return err - } - var d int - t.Root, d = t.Root.insert(e, r.Start, e.ID(), fast) - t.Count += d - t.Root.Color = llrb.Black - return -} - -func (n *llrbNode) insert( - e Interface, min Comparable, id uintptr, fast bool, -) (root *llrbNode, d int) { - if n == nil { - return &llrbNode{Elem: e, Range: e.Range()}, 1 - } else if n.Elem == nil { - n.Elem = e - if !fast { - n.adjustRange() - } - return n, 1 - } - - if LLRBMode == TD234 { - if n.Left.color() == llrb.Red && n.Right.color() == llrb.Red { - n.flipColors() - } - } - - switch c := min.Compare(n.Elem.Range().Start); { - case c == 0: - switch eid := n.Elem.ID(); { - case id == eid: - n.Elem = e - if !fast { - n.Range.End = e.Range().End - } - case id < eid: - n.Left, d = n.Left.insert(e, min, id, fast) - default: - n.Right, d = n.Right.insert(e, min, id, fast) - } - case c < 0: - n.Left, d = n.Left.insert(e, min, id, fast) - default: - n.Right, d = n.Right.insert(e, min, id, fast) - } - - if n.Right.color() == llrb.Red && n.Left.color() == llrb.Black { - n = n.rotateLeft() - } - if n.Left.color() == llrb.Red && n.Left.Left.color() == llrb.Red { - n = n.rotateRight() - } - - if LLRBMode == BU23 { - if n.Left.color() == llrb.Red && n.Right.color() == llrb.Red { - n.flipColors() - } - } - - if !fast { - n.adjustRange() - } - root = n - - return -} - -var _ = (*llrbTree)(nil).DeleteMin - -// DeleteMin deletes the leftmost interval. -func (t *llrbTree) DeleteMin(fast bool) { - if t.Root == nil { - return - } - var d int - t.Root, d = t.Root.deleteMin(fast) - t.Count += d - if t.Root == nil { - return - } - t.Root.Color = llrb.Black -} - -func (n *llrbNode) deleteMin(fast bool) (root *llrbNode, d int) { - if n.Left == nil { - return nil, -1 - } - if n.Left.color() == llrb.Black && n.Left.Left.color() == llrb.Black { - n = n.moveRedLeft() - } - n.Left, d = n.Left.deleteMin(fast) - if n.Left == nil { - n.Range.Start = n.Elem.Range().Start - } - - root = n.fixUp(fast) - - return -} - -var _ = (*llrbTree)(nil).DeleteMax - -// DeleteMax deletes the rightmost interval. -func (t *llrbTree) DeleteMax(fast bool) { - if t.Root == nil { - return - } - var d int - t.Root, d = t.Root.deleteMax(fast) - t.Count += d - if t.Root == nil { - return - } - t.Root.Color = llrb.Black -} - -func (n *llrbNode) deleteMax(fast bool) (root *llrbNode, d int) { - if n.Left != nil && n.Left.color() == llrb.Red { - n = n.rotateRight() - } - if n.Right == nil { - return nil, -1 - } - if n.Right.color() == llrb.Black && n.Right.Left.color() == llrb.Black { - n = n.moveRedRight() - } - n.Right, d = n.Right.deleteMax(fast) - if n.Right == nil { - n.Range.End = n.Elem.Range().End - } - - root = n.fixUp(fast) - - return -} - -func (t *llrbTree) Delete(e Interface, fast bool) (err error) { - r := e.Range() - if err := rangeError(r); err != nil { - return err - } - if t.Root == nil || !t.Overlapper.Overlap(r, t.Root.Range) { - return - } - var d int - t.Root, d = t.Root.delete(r.Start, e.ID(), fast) - t.Count += d - if t.Root == nil { - return - } - t.Root.Color = llrb.Black - return -} - -func (n *llrbNode) delete(min Comparable, id uintptr, fast bool) (root *llrbNode, d int) { - if p := min.Compare(n.Elem.Range().Start); p < 0 || (p == 0 && id < n.Elem.ID()) { - if n.Left != nil { - if n.Left.color() == llrb.Black && n.Left.Left.color() == llrb.Black { - n = n.moveRedLeft() - } - n.Left, d = n.Left.delete(min, id, fast) - if n.Left == nil { - n.Range.Start = n.Elem.Range().Start - } - } - } else { - if n.Left.color() == llrb.Red { - n = n.rotateRight() - } - if n.Right == nil && id == n.Elem.ID() { - return nil, -1 - } - if n.Right != nil { - if n.Right.color() == llrb.Black && n.Right.Left.color() == llrb.Black { - n = n.moveRedRight() - } - if id == n.Elem.ID() { - n.Elem = n.Right.min().Elem - n.Right, d = n.Right.deleteMin(fast) - } else { - n.Right, d = n.Right.delete(min, id, fast) - } - if n.Right == nil { - n.Range.End = n.Elem.Range().End - } - } - } - - root = n.fixUp(fast) - - return -} - -var _ = (*llrbTree)(nil).Min - -// Min returns the leftmost interval stored in the tree. -func (t *llrbTree) Min() Interface { - if t.Root == nil { - return nil - } - return t.Root.min().Elem -} - -func (n *llrbNode) min() *llrbNode { - for ; n.Left != nil; n = n.Left { - } - return n -} - -var _ = (*llrbTree)(nil).Max - -// Max returns the rightmost interval stored in the tree. -func (t *llrbTree) Max() Interface { - if t.Root == nil { - return nil - } - return t.Root.max().Elem -} - -func (n *llrbNode) max() *llrbNode { - for ; n.Right != nil; n = n.Right { - } - return n -} - -var _ = (*llrbTree)(nil).Floor - -// Floor returns the largest value equal to or less than the query q according to -// q.Start.Compare(), with ties broken by comparison of ID() values. -func (t *llrbTree) Floor(q Interface) (o Interface, err error) { - if t.Root == nil { - return - } - n := t.Root.floor(q.Range().Start, q.ID()) - if n == nil { - return - } - return n.Elem, nil -} - -func (n *llrbNode) floor(m Comparable, id uintptr) *llrbNode { - if n == nil { - return nil - } - switch c := m.Compare(n.Elem.Range().Start); { - case c == 0: - switch eid := n.Elem.ID(); { - case id == eid: - return n - case id < eid: - return n.Left.floor(m, id) - default: - if r := n.Right.floor(m, id); r != nil { - return r - } - } - case c < 0: - return n.Left.floor(m, id) - default: - if r := n.Right.floor(m, id); r != nil { - return r - } - } - return n -} - -var _ = (*llrbTree)(nil).Ceil - -// Ceil returns the smallest value equal to or greater than the query q according to -// q.Start.Compare(), with ties broken by comparison of ID() values. -func (t *llrbTree) Ceil(q Interface) (o Interface, err error) { - if t.Root == nil { - return - } - n := t.Root.ceil(q.Range().Start, q.ID()) - if n == nil { - return - } - return n.Elem, nil -} - -func (n *llrbNode) ceil(m Comparable, id uintptr) *llrbNode { - if n == nil { - return nil - } - switch c := m.Compare(n.Elem.Range().Start); { - case c == 0: - switch eid := n.Elem.ID(); { - case id == eid: - return n - case id > eid: - return n.Right.ceil(m, id) - default: - if l := n.Left.ceil(m, id); l != nil { - return l - } - } - case c > 0: - return n.Right.ceil(m, id) - default: - if l := n.Left.ceil(m, id); l != nil { - return l - } - } - return n -} - -func (t *llrbTree) Do(fn Operation) bool { - if t.Root == nil { - return false - } - return t.Root.do(fn) -} - -func (n *llrbNode) do(fn Operation) (done bool) { - if n.Left != nil { - done = n.Left.do(fn) - if done { - return - } - } - done = fn(n.Elem) - if done { - return - } - if n.Right != nil { - done = n.Right.do(fn) - } - return -} - -var _ = (*llrbTree)(nil).DoReverse - -// DoReverse performs fn on all intervals stored in the tree, but in reverse of sort order. A boolean -// is returned indicating whether the Do traversal was interrupted by an Operation returning true. -// If fn alters stored intervals' sort relationships, future tree operation behaviors are undefined. -func (t *llrbTree) DoReverse(fn Operation) bool { - if t.Root == nil { - return false - } - return t.Root.doReverse(fn) -} - -func (n *llrbNode) doReverse(fn Operation) (done bool) { - if n.Right != nil { - done = n.Right.doReverse(fn) - if done { - return - } - } - done = fn(n.Elem) - if done { - return - } - if n.Left != nil { - done = n.Left.doReverse(fn) - } - return -} - -var _ = (*llrbTree)(nil).DoMatchingReverse - -func (t *llrbTree) DoMatching(fn Operation, r Range) bool { - if t.Root != nil && t.Overlapper.Overlap(r, t.Root.Range) { - return t.Root.doMatch(fn, r, t.Overlapper.Overlap) - } - return false -} - -func (n *llrbNode) doMatch(fn Operation, r Range, overlaps func(Range, Range) bool) (done bool) { - if n.Left != nil && overlaps(r, n.Left.Range) { - done = n.Left.doMatch(fn, r, overlaps) - if done { - return - } - } - if overlaps(r, n.Elem.Range()) { - done = fn(n.Elem) - if done { - return - } - } - if n.Right != nil && overlaps(r, n.Right.Range) { - done = n.Right.doMatch(fn, r, overlaps) - } - return -} - -var _ = (*llrbTree)(nil).DoMatchingReverse - -// DoMatchingReverse performs fn on all intervals stored in the tree that match r according to -// t.Overlapper, with Overlapper() used to guide tree traversal, so DoMatching() will outperform -// Do() with a called conditional function if the condition is based on sort order, but can not -// be reliably used if the condition is independent of sort order. A boolean is returned indicating -// whether the Do traversal was interrupted by an Operation returning true. If fn alters stored -// intervals' sort relationships, future tree operation behaviors are undefined. -func (t *llrbTree) DoMatchingReverse(fn Operation, r Range) bool { - if t.Root != nil && t.Overlapper.Overlap(r, t.Root.Range) { - return t.Root.doMatchReverse(fn, r, t.Overlapper.Overlap) - } - return false -} - -func (n *llrbNode) doMatchReverse( - fn Operation, r Range, overlaps func(Range, Range) bool, -) (done bool) { - if n.Right != nil && overlaps(r, n.Right.Range) { - done = n.Right.doMatchReverse(fn, r, overlaps) - if done { - return - } - } - if overlaps(r, n.Elem.Range()) { - done = fn(n.Elem) - if done { - return - } - } - if n.Left != nil && overlaps(r, n.Left.Range) { - done = n.Left.doMatchReverse(fn, r, overlaps) - } - return -} - -type llrbTreeIterator struct { - stack []*llrbNode -} - -func (ti *llrbTreeIterator) Next() (i Interface, ok bool) { - if len(ti.stack) == 0 { - return nil, false - } - n := ti.stack[len(ti.stack)-1] - ti.stack = ti.stack[:len(ti.stack)-1] - for r := n.Right; r != nil; r = r.Left { - ti.stack = append(ti.stack, r) - } - return n.Elem, true -} - -func (t *llrbTree) Iterator() TreeIterator { - var ti llrbTreeIterator - for n := t.Root; n != nil; n = n.Left { - ti.stack = append(ti.stack, n) - } - return &ti -} - -func (t *llrbTree) Clear() { - t.Root = nil - t.Count = 0 -} - -func (t *llrbTree) Clone() Tree { - panic("unimplemented") -} diff --git a/postgres/parser/interval/range_group.go b/postgres/parser/interval/range_group.go deleted file mode 100644 index 69f153310d..0000000000 --- a/postgres/parser/interval/range_group.go +++ /dev/null @@ -1,833 +0,0 @@ -// Copyright 2023 Dolthub, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Copyright 2016 The Cockroach Authors. -// -// Use of this software is governed by the Business Source License -// included in the file licenses/BSL.txt. -// -// As of the Change Date specified in that file, in accordance with -// the Business Source License, use of this software will be governed -// by the Apache License, Version 2.0, included in the file -// licenses/APL.txt. - -package interval - -import ( - "bytes" - "container/list" - "fmt" -) - -// RangeGroup represents a set of possibly disjointed Ranges. The -// interface exposes methods to manipulate the group by adding and -// subtracting Ranges. All methods requiring a Range will panic -// if the provided range is inverted or empty. -// -// One use case of the interface is to add ranges to the group and -// observe whether the addition increases the size of the group or -// not, indicating whether the new range's interval is redundant, or -// if it is needed for the full composition of the group. Because -// the RangeGroup builds as more ranges are added, insertion order of -// the ranges is critical. For instance, if two identical ranges are -// added, only the first to be added with Add will return true, as it -// will be the only one to expand the group. -// -// Another use case of the interface is to add and subtract ranges as -// needed to the group, allowing the internals of the implementation -// to coalesce and split ranges when needed to factor the group to -// its minimum number of disjoint ranges. -type RangeGroup interface { - // Add will attempt to add the provided Range to the RangeGroup, - // returning whether the addition increased the range of the group - // or not. - Add(Range) bool - // Sub will attempt to remove the provided Range from the RangeGroup, - // returning whether the subtraction reduced the range of the group - // or not. - Sub(Range) bool - // Clear clears all ranges from the RangeGroup, resetting it to be - // used again. - Clear() - // Overlaps returns whether the provided Range is partially contained - // within the group of Ranges in the RangeGroup. - Overlaps(Range) bool - // Encloses returns whether the provided Range is fully contained - // within the group of Ranges in the RangeGroup. - Encloses(Range) bool - // ForEach calls the provided function with each Range stored in - // the group. An error is returned indicating whether the callback - // function saw an error, whereupon the Range iteration will halt - // (potentially prematurely) and the error will be returned from ForEach - // itself. If no error is returned from the callback, the method - // will visit all Ranges in the group before returning a nil error. - ForEach(func(Range) error) error - // Iterator returns an iterator to visit each Range stored in the - // group, in-order. It is not safe to mutate the RangeGroup while - // iteration is being performed. - Iterator() RangeGroupIterator - // Len returns the number of Ranges currently within the RangeGroup. - // This will always be equal to or less than the number of ranges added, - // as ranges that overlap will merge to produce a single larger range. - Len() int - fmt.Stringer -} - -// RangeGroupIterator is an iterator that walks in-order over a RangeGroup. -type RangeGroupIterator interface { - // Next returns the next Range in the RangeGroup. It returns false - // if there are no more Ranges. - Next() (Range, bool) -} - -const rangeListNodeBucketSize = 8 - -type rangeListNode struct { - slots [rangeListNodeBucketSize]Range // ordered, non-overlapping - len int -} - -func newRangeListNodeWithRange(r Range) *rangeListNode { - var n rangeListNode - n.push(r) - return &n -} - -func (n *rangeListNode) push(r Range) { - n.slots[n.len] = r - n.len++ -} - -func (n *rangeListNode) full() bool { - return n.len == len(n.slots) -} - -func (n *rangeListNode) min() Comparable { - return n.slots[0].Start -} - -func (n *rangeListNode) max() Comparable { - return n.slots[n.len-1].End -} - -// findIdx finds the upper-bound slot index that the provided range should fit -// in the rangeListNode. It also returns whether the slot is currently occupied -// by an overlapping range. -func (n *rangeListNode) findIdx(r Range, inclusive bool) (int, bool) { - overlapFn := overlapsExclusive - passedCmp := 0 - if inclusive { - overlapFn = overlapsInclusive - passedCmp = -1 - } - for i, nr := range n.slots[:n.len] { - switch { - case overlapFn(nr, r): - return i, true - case r.End.Compare(nr.Start) <= passedCmp: - // Past where overlapping ranges would be. - return i, false - } - } - return n.len, false -} - -// rangeList is an implementation of a RangeGroup using a bucketted linked list -// to sequentially order non-overlapping ranges. -// -// rangeList is not safe for concurrent use by multiple goroutines. -type rangeList struct { - ll list.List - len int -} - -// NewRangeList constructs a linked-list backed RangeGroup. -func NewRangeList() RangeGroup { - var rl rangeList - rl.ll.Init() - rl.ll.PushFront(&rangeListNode{}) - return &rl -} - -// findNode returns the upper-bound node that the range would be bucketted in, -// along with that node's previous element. It also returns whether the range -// overlaps with the bounds of the node. -func (rl *rangeList) findNode(r Range, inclusive bool) (prev, cur *list.Element, inCur bool) { - if err := rangeError(r); err != nil { - panic(err) - } - reachedCmp := 1 - passedCmp := 0 - if inclusive { - reachedCmp = 0 - passedCmp = -1 - } - for e := rl.ll.Front(); e != nil; e = e.Next() { - n := e.Value.(*rangeListNode) - if n.len == 0 { - // The node is empty. This must be the last node in the list. - return prev, e, false - } - // Check if the range starts at a value less than (or equal to, for - // inclusive) the maximum value in this node. This is what we want. - if n.max().Compare(r.Start) >= reachedCmp { - // Determine whether the range overlap the node's bounds. - inCur := n.min().Compare(r.End) <= passedCmp - return prev, e, inCur - } - prev = e - } - return prev, nil, false -} - -// insertAtIdx inserts the provided range at the specified index in the node. -// It performs any necessary slot movement to keep the ranges ordered. -// -// Note: e.Value is expected to be n, but we want to avoid repeated type -// assertions so both are taken as arguments. -func (rl *rangeList) insertAtIdx(e *list.Element, n *rangeListNode, r Range, i int) { - if n.full() { - // If the current node is full, we're going to need to shift off a range - // from one of the slots into a different node. If i is not pointing - // past the end of the range, we need to shift off the range currently - // in the last slot in this node. If i is pointing past the end of the - // range, we can just shift the new range to a different node without - // making any changes to this node. - toShift := n.slots[n.len-1] - noLocalChanges := false - if i == n.len { - toShift = r - - // We're going to add range r to a different node, so there will be - // no local changes to this node. - noLocalChanges = true - } - - // Check if the next node has room. Note that we only call insertAtIdx - // recursively on the next node if it is not full. We don't want to - // shift recursively all the way down the list. Instead, we'll just pay - // the constant cost below of inserting a fresh node in-between the - // current and next node. - next := e.Next() - insertedInNext := false - if next != nil { - nextN := next.Value.(*rangeListNode) - if !nextN.full() { - rl.insertAtIdx(next, nextN, toShift, 0) - insertedInNext = true - } - } - if !insertedInNext { - newN := newRangeListNodeWithRange(toShift) - rl.ll.InsertAfter(newN, e) - rl.len++ - } - - if noLocalChanges { - return - } - } else { - n.len++ - rl.len++ - } - // Shift all others over and copy the new range in. Because of - // the n.full check, we know that we'll have at least one free - // slot open at the end. - copy(n.slots[i+1:n.len], n.slots[i:n.len-1]) - n.slots[i] = r -} - -// Add implements RangeGroup. It iterates over the current ranges in the -// rangeList to find which overlap with the new range. If there is no -// overlap, the new range will be added and the function will return true. -// If there is some overlap, the function will return true if the new -// range increases the range of the rangeList, in which case it will be -// added to the list, and false if it does not increase the range, in which -// case it won't be added. If the range is added, the function will also attempt -// to merge any ranges within the list that now overlap. -func (rl *rangeList) Add(r Range) bool { - prev, cur, inCur := rl.findNode(r, true /* inclusive */) - - var prevN *rangeListNode - if prev != nil { - prevN = prev.Value.(*rangeListNode) - } - if !inCur && prevN != nil && !prevN.full() { - // There is a previous node. Add the range to the end of that node. - prevN.push(r) - rl.len++ - return true - } - - if cur != nil { - n := cur.Value.(*rangeListNode) - i, ok := n.findIdx(r, true /* inclusive */) - if !ok { - // The range can't be merged with any existing ranges, but should - // instead be inserted at slot i. This may force us to shift over - // other slots. - rl.insertAtIdx(cur, n, r, i) - return true - } - - // If a current range fully contains the new range, no need to add it. - nr := n.slots[i] - if contains(nr, r) { - return false - } - - // Merge as many ranges as possible and replace old range. All merges - // will be made into n.slots[i] because adding a range to the RangeGroup - // can result in only at most one group of ranges being merged. - // - // For example: - // existing ranges : ---- ----- ---- ---- ---- - // new range : -------------- - // resulting ranges: ---- ------------------- ---- - // - // In this example, n.slots[i] is the first existing range that overlaps - // with the new range. - newR := merge(nr, r) - n.slots[i] = newR - - // Each iteration attempts to merge all of the ranges in a rangeListNode. - mergeElem := cur - origNode := true - for { - mergeN := mergeElem.Value.(*rangeListNode) - origLen := mergeN.len - - // mergeState is the slot index to begin merging from - mergeStart := 0 - if origNode { - mergeStart = i + 1 - } - - // Each iteration attempts to merge a single range into the current - // merge batch. - j := mergeStart - for ; j < origLen; j++ { - mergeR := mergeN.slots[j] - - if overlapsInclusive(newR, mergeR) { - newR = merge(newR, mergeR) - n.slots[i] = newR - mergeN.len-- - rl.len-- - } else { - // If the ranges don't overlap, that means index j and up - // are still needed in the current node. Shift these over. - copy(mergeN.slots[mergeStart:mergeStart+origLen-j], mergeN.slots[j:origLen]) - return true - } - } - - // If we didn't break, that means that all of the slots including - // and after mergeStart in the current node were merged. Continue - // onto the next node. - nextE := mergeElem.Next() - if !origNode { - // If this is not the current node, we can delete it - // completely. - rl.ll.Remove(mergeElem) - } - if nextE == nil { - return true - } - mergeElem = nextE - origNode = false - } - } - - // The new range couldn't be added to the previous or the current node. - // We'll have to create a new node for the range. - n := newRangeListNodeWithRange(r) - if prevN != nil { - // There is a previous node and it is full. - rl.ll.InsertAfter(n, prev) - } else { - // There is no previous node. Add the range to a new node in the front - // of the list. - rl.ll.PushFront(n) - } - rl.len++ - return true -} - -// Sub implements RangeGroup. It iterates over the current ranges in the -// rangeList to find which overlap with the range to subtract. For all -// ranges that overlap with the provided range, the overlapping segment of -// the range is removed. If the provided range fully contains a range in -// the rangeList, the range in the rangeList will be removed. The method -// returns whether the subtraction resulted in any decrease to the size -// of the RangeGroup. -func (rl *rangeList) Sub(r Range) bool { - _, cur, inCur := rl.findNode(r, false /* inclusive */) - if !inCur { - // The range does not overlap any nodes. Nothing to do. - return false - } - - n := cur.Value.(*rangeListNode) - i, ok := n.findIdx(r, false /* inclusive */) - if !ok { - // The range does not overlap any ranges in the node. Nothing to do. - return false - } - - for { - nr := n.slots[i] - if !overlapsExclusive(nr, r) { - // The range does not overlap nr so stop trying to subtract. The - // findIdx check above guarantees that this will never be the case - // for the first iteration of this loop. - return true - } - - sCmp := nr.Start.Compare(r.Start) - eCmp := nr.End.Compare(r.End) - delStart := sCmp >= 0 - delEnd := eCmp <= 0 - - switch { - case delStart && delEnd: - // Remove the entire range. - n.len-- - rl.len-- - if n.len == 0 { - // Move to the next node, removing the current node as long as - // it's not the only one in the list. - i = 0 - next := cur.Next() - if rl.len > 0 { - rl.ll.Remove(cur) - } - if next == nil { - return true - } - cur = next - n = cur.Value.(*rangeListNode) - continue - } else { - // Replace the current Range. - copy(n.slots[i:n.len], n.slots[i+1:n.len+1]) - } - // Don't increment i. - case delStart: - // Remove the start of the range by truncating. Can return after. - n.slots[i].Start = r.End - return true - case delEnd: - // Remove the end of the range by truncating. - n.slots[i].End = r.Start - i++ - default: - // Remove the middle of the range by splitting and truncating. - oldEnd := nr.End - n.slots[i].End = r.Start - - // Create right side of split. Can return after. - rSplit := Range{Start: r.End, End: oldEnd} - rl.insertAtIdx(cur, n, rSplit, i+1) - return true - } - - // Move to the next node, if necessary. - if i >= n.len { - i = 0 - cur = cur.Next() - if cur == nil { - return true - } - n = cur.Value.(*rangeListNode) - } - } -} - -// Clear implements RangeGroup. It clears all ranges from the -// rangeList. -func (rl *rangeList) Clear() { - // Empty the first node, but keep it in the list. - f := rl.ll.Front().Value.(*rangeListNode) - *f = rangeListNode{} - - // If the list has more than one node in it, remove all but the first. - if rl.ll.Len() > 1 { - rl.ll.Init() - rl.ll.PushBack(f) - } - rl.len = 0 -} - -// Overlaps implements RangeGroup. It returns whether the provided -// Range is partially contained within the group of Ranges in the rangeList. -func (rl *rangeList) Overlaps(r Range) bool { - if _, cur, inCur := rl.findNode(r, false /* inclusive */); inCur { - n := cur.Value.(*rangeListNode) - if _, ok := n.findIdx(r, false /* inclusive */); ok { - return true - } - } - return false -} - -// Encloses implements RangeGroup. It returns whether the provided -// Range is fully contained within the group of Ranges in the rangeList. -func (rl *rangeList) Encloses(r Range) bool { - if _, cur, inCur := rl.findNode(r, false /* inclusive */); inCur { - n := cur.Value.(*rangeListNode) - if i, ok := n.findIdx(r, false /* inclusive */); ok { - return contains(n.slots[i], r) - } - } - return false -} - -// ForEach implements RangeGroup. It calls the provided function f -// with each Range stored in the rangeList. -func (rl *rangeList) ForEach(f func(Range) error) error { - it := rangeListIterator{e: rl.ll.Front()} - for r, ok := it.Next(); ok; r, ok = it.Next() { - if err := f(r); err != nil { - return err - } - } - return nil -} - -// rangeListIterator is an in-order iterator operating over a rangeList. -type rangeListIterator struct { - e *list.Element - idx int // next slot index -} - -// Next implements RangeGroupIterator. It returns the next Range in the -// rangeList, or false. -func (rli *rangeListIterator) Next() (r Range, ok bool) { - if rli.e != nil { - n := rli.e.Value.(*rangeListNode) - - // Get current index, return if invalid. - curIdx := rli.idx - if curIdx >= n.len { - return Range{}, false - } - - // Move index and Element pointer forwards. - rli.idx = curIdx + 1 - if rli.idx >= n.len { - rli.idx = 0 - rli.e = rli.e.Next() - } - - return n.slots[curIdx], true - } - return Range{}, false -} - -// Iterator implements RangeGroup. It returns an iterator to iterate over -// the group of ranges. -func (rl *rangeList) Iterator() RangeGroupIterator { - return &rangeListIterator{e: rl.ll.Front()} -} - -// Len implements RangeGroup. It returns the number of ranges in -// the rangeList. -func (rl *rangeList) Len() int { - return rl.len -} - -func (rl *rangeList) String() string { - return rgString(rl) -} - -// rangeTree is an implementation of a RangeGroup using an interval -// tree to efficiently store and search for non-overlapping ranges. -// -// rangeTree is not safe for concurrent use by multiple goroutines. -type rangeTree struct { - t Tree - idCount uintptr -} - -// NewRangeTree constructs an interval tree backed RangeGroup. -func NewRangeTree() RangeGroup { - return &rangeTree{ - t: NewTree(InclusiveOverlapper), - } -} - -// rangeKey implements Interface and can be inserted into a Tree. It -// provides uniqueness as well as a key interval. -type rangeKey struct { - r Range - id uintptr -} - -var _ Interface = rangeKey{} - -// makeKey creates a new rangeKey defined by the provided range. -func (rt *rangeTree) makeKey(r Range) rangeKey { - rt.idCount++ - return rangeKey{ - r: r, - id: rt.idCount, - } -} - -// Range implements Interface. -func (rk rangeKey) Range() Range { - return rk.r -} - -// ID implements Interface. -func (rk rangeKey) ID() uintptr { - return rk.id -} - -func (rk rangeKey) String() string { - return fmt.Sprintf("%d: %q-%q", rk.id, rk.r.Start, rk.r.End) -} - -// Add implements RangeGroup. It first uses the interval tree to lookup -// the current ranges which overlap with the new range. If there is no -// overlap, the new range will be added and the function will return true. -// If there is some overlap, the function will return true if the new -// range increases the range of the rangeTree, in which case it will be -// added to the tree, and false if it does not increase the range, in which -// case it won't be added. If the range is added, the function will also attempt -// to merge any ranges within the tree that now overlap. -func (rt *rangeTree) Add(r Range) bool { - if err := rangeError(r); err != nil { - panic(err) - } - overlaps := rt.t.Get(r) - if len(overlaps) == 0 { - key := rt.makeKey(r) - if err := rt.t.Insert(&key, false /* fast */); err != nil { - panic(err) - } - return true - } - first := overlaps[0].(*rangeKey) - - // If a current range fully contains the new range, no - // need to add it. - if contains(first.r, r) { - return false - } - - // Merge as many ranges as possible, and replace old range. - first.r = merge(first.r, r) - for _, o := range overlaps[1:] { - other := o.(*rangeKey) - first.r = merge(first.r, other.r) - if err := rt.t.Delete(o, true /* fast */); err != nil { - panic(err) - } - } - rt.t.AdjustRanges() - return true -} - -// Sub implements RangeGroup. It first uses the interval tree to lookup -// the current ranges which overlap with the range to subtract. For all -// ranges that overlap with the provided range, the overlapping segment of -// the range is removed. If the provided range fully contains a range in -// the rangeTree, the range in the rangeTree will be removed. The method -// returns whether the subtraction resulted in any decrease to the size -// of the RangeGroup. -func (rt *rangeTree) Sub(r Range) bool { - if err := rangeError(r); err != nil { - panic(err) - } - overlaps := rt.t.GetWithOverlapper(r, ExclusiveOverlapper) - if len(overlaps) == 0 { - return false - } - - for _, o := range overlaps { - rk := o.(*rangeKey) - sCmp := rk.r.Start.Compare(r.Start) - eCmp := rk.r.End.Compare(r.End) - - delStart := sCmp >= 0 - delEnd := eCmp <= 0 - - switch { - case delStart && delEnd: - // Remove the entire range. - if err := rt.t.Delete(o, true /* fast */); err != nil { - panic(err) - } - case delStart: - // Remove the start of the range by truncating. - rk.r.Start = r.End - case delEnd: - // Remove the end of the range by truncating. - rk.r.End = r.Start - default: - // Remove the middle of the range by splitting. - oldEnd := rk.r.End - rk.r.End = r.Start - - rSplit := Range{Start: r.End, End: oldEnd} - rKey := rt.makeKey(rSplit) - if err := rt.t.Insert(&rKey, true /* fast */); err != nil { - panic(err) - } - } - } - rt.t.AdjustRanges() - return true -} - -// Clear implements RangeGroup. It clears all rangeKeys from the rangeTree. -func (rt *rangeTree) Clear() { - rt.t.Clear() -} - -// Overlaps implements RangeGroup. It returns whether the provided -// Range is partially contained within the group of Ranges in the rangeTree. -func (rt *rangeTree) Overlaps(r Range) bool { - if err := rangeError(r); err != nil { - panic(err) - } - overlaps := rt.t.GetWithOverlapper(r, ExclusiveOverlapper) - return len(overlaps) > 0 -} - -// Encloses implements RangeGroup. It returns whether the provided -// Range is fully contained within the group of Ranges in the rangeTree. -func (rt *rangeTree) Encloses(r Range) bool { - if err := rangeError(r); err != nil { - panic(err) - } - overlaps := rt.t.GetWithOverlapper(r, ExclusiveOverlapper) - if len(overlaps) != 1 { - return false - } - first := overlaps[0].(*rangeKey) - return contains(first.r, r) -} - -// ForEach implements RangeGroup. It calls the provided function f -// with each Range stored in the rangeTree. -func (rt *rangeTree) ForEach(f func(Range) error) error { - var err error - rt.t.Do(func(i Interface) bool { - err = f(i.Range()) - return err != nil - }) - return err -} - -// rangeListIterator is an in-order iterator operating over a rangeTree. -type rangeTreeIterator struct { - it TreeIterator -} - -// Next implements RangeGroupIterator. It returns the next Range in the -// rangeTree, or false. -func (rti *rangeTreeIterator) Next() (r Range, ok bool) { - i, ok := rti.it.Next() - if !ok { - return Range{}, false - } - return i.Range(), true -} - -// Iterator implements RangeGroup. It returns an iterator to iterate over -// the group of ranges. -func (rt *rangeTree) Iterator() RangeGroupIterator { - return &rangeTreeIterator{it: rt.t.Iterator()} -} - -// Len implements RangeGroup. It returns the number of rangeKeys in -// the rangeTree. -func (rt *rangeTree) Len() int { - return rt.t.Len() -} - -func (rt *rangeTree) String() string { - return rgString(rt) -} - -// contains returns if the range in the out range fully contains the -// in range. -func contains(out, in Range) bool { - return in.Start.Compare(out.Start) >= 0 && out.End.Compare(in.End) >= 0 -} - -// merge merges the provided ranges together into their union range. The -// ranges must overlap or the function will not produce the correct output. -func merge(l, r Range) Range { - start := l.Start - if r.Start.Compare(start) < 0 { - start = r.Start - } - end := l.End - if r.End.Compare(end) > 0 { - end = r.End - } - return Range{Start: start, End: end} -} - -// rgString returns a string representation of the ranges in a RangeGroup. -func rgString(rg RangeGroup) string { - var buffer bytes.Buffer - buffer.WriteRune('[') - space := false - if err := rg.ForEach(func(r Range) error { - if space { - buffer.WriteRune(' ') - } - buffer.WriteString(r.String()) - space = true - return nil - }); err != nil { - panic(err) - } - buffer.WriteRune(']') - return buffer.String() -} - -// RangeGroupsOverlap determines if two RangeGroups contain any overlapping -// Ranges or if they are fully disjoint. It does so by iterating over the -// RangeGroups together and comparing subsequent ranges. -func RangeGroupsOverlap(rg1, rg2 RangeGroup) bool { - it1, it2 := rg1.Iterator(), rg2.Iterator() - r1, ok1 := it1.Next() - r2, ok2 := it2.Next() - if !ok1 || !ok2 { - return false - } - for { - // Check if the current pair of Ranges overlap. - if overlapsExclusive(r1, r2) { - return true - } - - // If not, advance the Range further behind. - var ok bool - if r1.Start.Compare(r2.Start) < 0 { - r1, ok = it1.Next() - } else { - r2, ok = it2.Next() - } - if !ok { - return false - } - } -} diff --git a/postgres/parser/interval/td234.go b/postgres/parser/interval/td234.go deleted file mode 100644 index 63c65e8b0b..0000000000 --- a/postgres/parser/interval/td234.go +++ /dev/null @@ -1,36 +0,0 @@ -// Copyright 2023 Dolthub, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Copyright 2019 The Cockroach Authors. -// -// Use of this software is governed by the Business Source License -// included in the file licenses/BSL.txt. -// -// As of the Change Date specified in that file, in accordance with -// the Business Source License, use of this software will be governed -// by the Apache License, Version 2.0, included in the file -// licenses/APL.txt. - -// Copyright ©2014 The bíogo Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in licenses/BSD-biogo.txt. - -// This code originated in the github.com/biogo/store/interval package. - -//go:build td234 -// +build td234 - -package interval - -const LLRBMode = TD234 diff --git a/postgres/parser/ipaddr/ipaddr.go b/postgres/parser/ipaddr/ipaddr.go index df259adb4e..b27ee13bc1 100644 --- a/postgres/parser/ipaddr/ipaddr.go +++ b/postgres/parser/ipaddr/ipaddr.go @@ -29,7 +29,6 @@ import ( "encoding/binary" "io" "math" - "math/rand" "net" "strconv" "strings" @@ -265,22 +264,6 @@ func ParseINet(s string, dest *IPAddr) error { return nil } -// RandIPAddr generates a random IPAddr. This includes random mask size and IP -// family. -func RandIPAddr(rng *rand.Rand) IPAddr { - var ipAddr IPAddr - if rng.Intn(2) > 0 { - ipAddr.Family = IPv4family - ipAddr.Mask = byte(rng.Intn(33)) - ipAddr.Addr = Addr(utils.FromInts(0, uint64(rng.Uint32())|IPv4mappedIPv6prefix)) - } else { - ipAddr.Family = IPv6family - ipAddr.Mask = byte(rng.Intn(129)) - ipAddr.Addr = Addr(utils.FromInts(rng.Uint64(), rng.Uint64())) - } - return ipAddr -} - // Hostmask returns the host masked IP. This is defined as the IP address bits // that are not masked. func (ipAddr *IPAddr) Hostmask() IPAddr { diff --git a/postgres/parser/json/encoded.go b/postgres/parser/json/encoded.go index d673b88075..3a685a5271 100644 --- a/postgres/parser/json/encoded.go +++ b/postgres/parser/json/encoded.go @@ -29,11 +29,10 @@ import ( "fmt" "sort" "strconv" + "sync" "unsafe" "github.com/cockroachdb/errors" - - "github.com/dolthub/doltgresql/postgres/parser/syncutil" ) type jsonEncoded struct { @@ -48,7 +47,7 @@ type jsonEncoded struct { // TODO(justin): for simplicity right now we use a mutex, we could be using // an atomic CAS though. mu struct { - syncutil.RWMutex + sync.RWMutex cachedDecoded JSON } diff --git a/postgres/parser/json/json.go b/postgres/parser/json/json.go index f94223c1e8..67d9da92b7 100644 --- a/postgres/parser/json/json.go +++ b/postgres/parser/json/json.go @@ -44,7 +44,6 @@ import ( "github.com/dolthub/doltgresql/postgres/parser/geo/geopb" "github.com/dolthub/doltgresql/postgres/parser/pgcode" "github.com/dolthub/doltgresql/postgres/parser/pgerror" - "github.com/dolthub/doltgresql/postgres/parser/unique" ) // Type represents a JSON type. @@ -781,10 +780,36 @@ func (j jsonArray) encodeInvertedIndexKeys(b []byte) ([][]byte, error) { // to emit duplicate keys from this method, as it's more expensive to // deduplicate keys via KV (which will actually write the keys) than via SQL // (just an in-memory sort and distinct). - outKeys = unique.UniquifyByteSlices(outKeys) + outKeys = UniquifyByteSlices(outKeys) return outKeys, nil } +// UniquifyByteSlices takes as input a slice of slices of bytes, and +// deduplicates them using a sort and unique. The output will not contain any +// duplicates but it will be sorted. +func UniquifyByteSlices(slices [][]byte) [][]byte { + if len(slices) == 0 { + return slices + } + // First sort: + sort.Slice(slices, func(i int, j int) bool { + return bytes.Compare(slices[i], slices[j]) < 0 + }) + // Then distinct: (wouldn't it be nice if Go had generics?) + lastUniqueIdx := 0 + for i := 1; i < len(slices); i++ { + if !bytes.Equal(slices[i], slices[lastUniqueIdx]) { + // We found a unique entry, at index i. The last unique entry in the array + // was at lastUniqueIdx, so set the entry after that one to our new unique + // entry, and bump lastUniqueIdx for the next loop iteration. + lastUniqueIdx++ + slices[lastUniqueIdx] = slices[i] + } + } + slices = slices[:lastUniqueIdx+1] + return slices +} + func (j jsonObject) encodeInvertedIndexKeys(b []byte) ([][]byte, error) { // Checking for an empty object. if len(j) == 0 { diff --git a/postgres/parser/kv/emptytxn.go b/postgres/parser/kv/emptytxn.go deleted file mode 100644 index b5d3bea9c0..0000000000 --- a/postgres/parser/kv/emptytxn.go +++ /dev/null @@ -1,18 +0,0 @@ -// Copyright 2023 Dolthub, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package kv - -// Txn only exists as it's used throughout the parser. This will eventually be removed. -type Txn struct{} diff --git a/postgres/parser/pgdate/zone_cache.go b/postgres/parser/pgdate/zone_cache.go index 231c162e4e..0b92ceba21 100644 --- a/postgres/parser/pgdate/zone_cache.go +++ b/postgres/parser/pgdate/zone_cache.go @@ -26,9 +26,8 @@ package pgdate import ( "fmt" + "sync" "time" - - "github.com/dolthub/doltgresql/postgres/parser/syncutil" ) // zoneCache stores the results of resolving time.Location instances. @@ -40,7 +39,7 @@ import ( // the string representation. type zoneCache struct { mu struct { - syncutil.Mutex + sync.Mutex named map[string]*zoneCacheEntry fixed map[int]*time.Location } diff --git a/postgres/parser/pgnotice/display_severity.go b/postgres/parser/pgnotice/display_severity.go deleted file mode 100644 index cd9cd3f3ae..0000000000 --- a/postgres/parser/pgnotice/display_severity.go +++ /dev/null @@ -1,114 +0,0 @@ -// Copyright 2023 Dolthub, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Copyright 2020 The Cockroach Authors. -// -// Use of this software is governed by the Business Source License -// included in the file licenses/BSL.txt. -// -// As of the Change Date specified in that file, in accordance with -// the Business Source License, use of this software will be governed -// by the Apache License, Version 2.0, included in the file -// licenses/APL.txt. - -package pgnotice - -import ( - "fmt" - "strings" -) - -// DisplaySeverity indicates the severity of a given error for the -// purposes of displaying notices. -// This corresponds to the allowed values for the `client_min_messages` -// variable in postgres. -type DisplaySeverity int - -// It is important to keep the same order here as Postgres. -// See https://www.postgresql.org/docs/current/runtime-config-client.html#GUC-CLIENT-MIN-MESSAGES. - -const ( - // DisplaySeverityError is a DisplaySeverity value allowing all notices - // of value <= DisplaySeverityError to display. - DisplaySeverityError = iota - // DisplaySeverityWarning is a DisplaySeverity value allowing all notices - // of value <= DisplaySeverityWarning to display. - DisplaySeverityWarning - // DisplaySeverityNotice is a DisplaySeverity value allowing all notices - // of value <= DisplaySeverityNotice to display. - DisplaySeverityNotice - // DisplaySeverityLog is a DisplaySeverity value allowing all notices - // of value <= DisplaySeverityLog.g to display. - DisplaySeverityLog - // DisplaySeverityDebug1 is a DisplaySeverity value allowing all notices - // of value <= DisplaySeverityDebug1 to display. - DisplaySeverityDebug1 - // DisplaySeverityDebug2 is a DisplaySeverity value allowing all notices - // of value <= DisplaySeverityDebug2 to display. - DisplaySeverityDebug2 - // DisplaySeverityDebug3 is a DisplaySeverity value allowing all notices - // of value <= DisplaySeverityDebug3 to display. - DisplaySeverityDebug3 - // DisplaySeverityDebug4 is a DisplaySeverity value allowing all notices - // of value <= DisplaySeverityDebug4 to display. - DisplaySeverityDebug4 - // DisplaySeverityDebug5 is a DisplaySeverity value allowing all notices - // of value <= DisplaySeverityDebug5 to display. - DisplaySeverityDebug5 -) - -// ParseDisplaySeverity translates a string to a DisplaySeverity. -// Returns the severity, and a bool indicating whether the severity exists. -func ParseDisplaySeverity(k string) (DisplaySeverity, bool) { - s, ok := namesToDisplaySeverity[strings.ToLower(k)] - return s, ok -} - -func (ns DisplaySeverity) String() string { - if ns < 0 || ns > DisplaySeverity(len(noticeDisplaySeverityNames)-1) { - return fmt.Sprintf("DisplaySeverity(%d)", ns) - } - return noticeDisplaySeverityNames[ns] -} - -// noticeDisplaySeverityNames maps a DisplaySeverity into it's string representation. -var noticeDisplaySeverityNames = [...]string{ - DisplaySeverityDebug5: "debug5", - DisplaySeverityDebug4: "debug4", - DisplaySeverityDebug3: "debug3", - DisplaySeverityDebug2: "debug2", - DisplaySeverityDebug1: "debug1", - DisplaySeverityLog: "log", - DisplaySeverityNotice: "notice", - DisplaySeverityWarning: "warning", - DisplaySeverityError: "error", -} - -// namesToDisplaySeverity is the reverse mapping from string to DisplaySeverity. -var namesToDisplaySeverity = map[string]DisplaySeverity{} - -// ValidDisplaySeverities returns a list of all valid severities. -func ValidDisplaySeverities() []string { - ret := make([]string, 0, len(namesToDisplaySeverity)) - for _, s := range noticeDisplaySeverityNames { - ret = append(ret, s) - } - return ret -} - -func init() { - for k, v := range noticeDisplaySeverityNames { - namesToDisplaySeverity[v] = DisplaySeverity(k) - } -} diff --git a/postgres/parser/pgnotice/pgnotice.go b/postgres/parser/pgnotice/pgnotice.go deleted file mode 100644 index b6475781c7..0000000000 --- a/postgres/parser/pgnotice/pgnotice.go +++ /dev/null @@ -1,48 +0,0 @@ -// Copyright 2023 Dolthub, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Copyright 2020 The Cockroach Authors. -// -// Use of this software is governed by the Business Source License -// included in the file licenses/BSL.txt. -// -// As of the Change Date specified in that file, in accordance with -// the Business Source License, use of this software will be governed -// by the Apache License, Version 2.0, included in the file -// licenses/APL.txt. - -package pgnotice - -import ( - "github.com/cockroachdb/errors" - - "github.com/dolthub/doltgresql/postgres/parser/pgcode" - "github.com/dolthub/doltgresql/postgres/parser/pgerror" -) - -// Newf generates a Notice with a format string. -func Newf(format string, args ...interface{}) error { - err := errors.NewWithDepthf(1, format, args...) - err = pgerror.WithCandidateCode(err, pgcode.SuccessfulCompletion) - err = pgerror.WithSeverity(err, "NOTICE") - return err -} - -// NewWithSeverityf generates a Notice with a format string and severity. -func NewWithSeverityf(severity string, format string, args ...interface{}) error { - err := errors.NewWithDepthf(1, format, args...) - err = pgerror.WithCandidateCode(err, pgcode.SuccessfulCompletion) - err = pgerror.WithSeverity(err, severity) - return err -} diff --git a/postgres/parser/privilege/privilege.go b/postgres/parser/privilege/privilege.go index 80d6ccd14d..4daf82421d 100644 --- a/postgres/parser/privilege/privilege.go +++ b/postgres/parser/privilege/privilege.go @@ -30,9 +30,6 @@ import ( "strings" "github.com/cockroachdb/errors" - - "github.com/dolthub/doltgresql/postgres/parser/pgcode" - "github.com/dolthub/doltgresql/postgres/parser/pgerror" ) //go:generate stringer -type=Kind @@ -175,21 +172,6 @@ func (pl List) ToBitField() uint32 { return ret } -// ListFromBitField takes a bitfield of privileges and a ObjectType -// returns a List. It is ordered in increasing value of privilege.Kind. -func ListFromBitField(m uint32, objectType ObjectType) List { - ret := List{} - - privileges := GetValidPrivilegesForObject(objectType) - - for _, p := range privileges { - if m&p.Mask() != 0 { - ret = append(ret, p) - } - } - return ret -} - // ListFromStrings takes a list of strings and attempts to build a list of Kind. // We convert each string to uppercase and search for it in the ByName map. // If an entry is not found in ByName, an error is returned. @@ -204,36 +186,3 @@ func ListFromStrings(strs []string) (List, error) { } return ret, nil } - -// ValidatePrivileges returns an error if any privilege in -// privileges cannot be granted on the given objectType. -// Currently db/schema/table can all be granted the same privileges. -func ValidatePrivileges(privileges List, objectType ObjectType) error { - validPrivs := GetValidPrivilegesForObject(objectType) - for _, priv := range privileges { - // Check if priv is in DBTablePrivileges. - if validPrivs.ToBitField()&priv.Mask() == 0 { - return pgerror.Newf(pgcode.InvalidGrantOperation, - "invalid privilege type %s for %s", priv.String(), objectType) - } - } - - return nil -} - -// GetValidPrivilegesForObject returns the list of valid privileges for the -// specified object type. -func GetValidPrivilegesForObject(objectType ObjectType) List { - switch objectType { - case Table, Database: - return DBTablePrivileges - case Schema: - return SchemaPrivileges - case Type: - return TypePrivileges - case Any: - return AllPrivileges - default: - panic(errors.AssertionFailedf("unknown object type %s", objectType)) - } -} diff --git a/postgres/parser/protoutil/clone.go b/postgres/parser/protoutil/clone.go index abddebe7c6..39f6fe3dc5 100644 --- a/postgres/parser/protoutil/clone.go +++ b/postgres/parser/protoutil/clone.go @@ -26,11 +26,10 @@ package protoutil import ( "reflect" + "sync" "github.com/cockroachdb/errors" "github.com/gogo/protobuf/proto" - - "github.com/dolthub/doltgresql/postgres/parser/syncutil" ) var verbotenKinds = [...]reflect.Kind{ @@ -43,7 +42,7 @@ type typeKey struct { } var types struct { - syncutil.Mutex + sync.Mutex known map[typeKey]reflect.Type } diff --git a/postgres/parser/ring/ring_buffer.go b/postgres/parser/ring/ring_buffer.go deleted file mode 100644 index d99e778142..0000000000 --- a/postgres/parser/ring/ring_buffer.go +++ /dev/null @@ -1,168 +0,0 @@ -// Copyright 2023 Dolthub, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Copyright 2018 The Cockroach Authors. -// -// Use of this software is governed by the Business Source License -// included in the file licenses/BSL.txt. -// -// As of the Change Date specified in that file, in accordance with -// the Business Source License, use of this software will be governed -// by the Apache License, Version 2.0, included in the file -// licenses/APL.txt. - -package ring - -// Buffer is a deque maintained over a ring buffer. -// -// Note: it is backed by a slice (unlike container/ring which is backed by a -// linked list). -type Buffer struct { - buffer []interface{} - head int // the index of the front of the buffer - tail int // the index of the first position after the end of the buffer - - // Indicates whether the buffer is empty. Necessary to distinguish - // between an empty buffer and a buffer that uses all of its capacity. - nonEmpty bool -} - -// Len returns the number of elements in the Buffer. -func (r *Buffer) Len() int { - if !r.nonEmpty { - return 0 - } - if r.head < r.tail { - return r.tail - r.head - } else if r.head == r.tail { - return cap(r.buffer) - } else { - return cap(r.buffer) + r.tail - r.head - } -} - -// Cap returns the capacity of the Buffer. -func (r *Buffer) Cap() int { - return cap(r.buffer) -} - -// Get returns an element at position pos in the Buffer (zero-based). -func (r *Buffer) Get(pos int) interface{} { - if !r.nonEmpty || pos < 0 || pos >= r.Len() { - panic("index out of bounds") - } - return r.buffer[(pos+r.head)%cap(r.buffer)] -} - -// GetFirst returns an element at the front of the Buffer. -func (r *Buffer) GetFirst() interface{} { - if !r.nonEmpty { - panic("getting first from empty ring buffer") - } - return r.buffer[r.head] -} - -// GetLast returns an element at the front of the Buffer. -func (r *Buffer) GetLast() interface{} { - if !r.nonEmpty { - panic("getting last from empty ring buffer") - } - return r.buffer[(cap(r.buffer)+r.tail-1)%cap(r.buffer)] -} - -func (r *Buffer) grow(n int) { - newBuffer := make([]interface{}, n) - if r.head < r.tail { - copy(newBuffer[:r.Len()], r.buffer[r.head:r.tail]) - } else { - copy(newBuffer[:cap(r.buffer)-r.head], r.buffer[r.head:]) - copy(newBuffer[cap(r.buffer)-r.head:r.Len()], r.buffer[:r.tail]) - } - r.head = 0 - r.tail = cap(r.buffer) - r.buffer = newBuffer -} - -func (r *Buffer) maybeGrow() { - if r.Len() != cap(r.buffer) { - return - } - n := 2 * cap(r.buffer) - if n == 0 { - n = 1 - } - r.grow(n) -} - -// AddFirst add element to the front of the Buffer and doubles it's underlying -// slice if necessary. -func (r *Buffer) AddFirst(element interface{}) { - r.maybeGrow() - r.head = (cap(r.buffer) + r.head - 1) % cap(r.buffer) - r.buffer[r.head] = element - r.nonEmpty = true -} - -// AddLast adds element to the end of the Buffer and doubles it's underlying -// slice if necessary. -func (r *Buffer) AddLast(element interface{}) { - r.maybeGrow() - r.buffer[r.tail] = element - r.tail = (r.tail + 1) % cap(r.buffer) - r.nonEmpty = true -} - -// RemoveFirst removes a single element from the front of the Buffer. -func (r *Buffer) RemoveFirst() { - if r.Len() == 0 { - panic("removing first from empty ring buffer") - } - r.buffer[r.head] = nil - r.head = (r.head + 1) % cap(r.buffer) - if r.head == r.tail { - r.nonEmpty = false - } -} - -// RemoveLast removes a single element from the end of the Buffer. -func (r *Buffer) RemoveLast() { - if r.Len() == 0 { - panic("removing last from empty ring buffer") - } - lastPos := (cap(r.buffer) + r.tail - 1) % cap(r.buffer) - r.buffer[lastPos] = nil - r.tail = lastPos - if r.tail == r.head { - r.nonEmpty = false - } -} - -// Reserve reserves the provided number of elemnets in the Buffer. It is an -// error to reserve a size less than the Buffer's current length. -func (r *Buffer) Reserve(n int) { - if n < r.Len() { - panic("reserving fewer elements than current length") - } else if n > cap(r.buffer) { - r.grow(n) - } -} - -// Reset makes Buffer treat its underlying memory as if it were empty. This -// allows for reusing the same memory again without explicitly removing old -// elements. -func (r *Buffer) Reset() { - r.head = 0 - r.tail = 0 - r.nonEmpty = false -} diff --git a/postgres/parser/sem/tree/aggregate_funcs.go b/postgres/parser/sem/tree/aggregate_funcs.go deleted file mode 100644 index cae3cb2dfd..0000000000 --- a/postgres/parser/sem/tree/aggregate_funcs.go +++ /dev/null @@ -1,58 +0,0 @@ -// Copyright 2023 Dolthub, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Copyright 2017 The Cockroach Authors. -// -// Use of this software is governed by the Business Source License -// included in the file licenses/BSL.txt. -// -// As of the Change Date specified in that file, in accordance with -// the Business Source License, use of this software will be governed -// by the Apache License, Version 2.0, included in the file -// licenses/APL.txt. - -package tree - -import "context" - -// AggregateFunc accumulates the result of a function of a Datum. -type AggregateFunc interface { - // Add accumulates the passed datums into the AggregateFunc. - // Most implementations require one and only one firstArg argument. - // If an aggregate function requires more than one argument, - // all additional arguments (after firstArg) are passed in as a - // variadic collection, otherArgs. - // This interface (as opposed to `args ...Datum`) avoids unnecessary - // allocation of otherArgs in the majority of cases. - Add(_ context.Context, firstArg Datum, otherArgs ...Datum) error - - // Result returns the current value of the accumulation. This value - // will be a deep copy of any AggregateFunc internal state, so that - // it will not be mutated by additional calls to Add. - Result() (Datum, error) - - // Reset resets the aggregate function which allows for reusing the same - // instance for computation without the need to create a new instance. - // Any memory is kept, if possible. - Reset(context.Context) - - // Close closes out the AggregateFunc and allows it to release any memory it - // requested during aggregation, and must be called upon completion of the - // aggregation. - Close(context.Context) - - // Size returns the size of the AggregateFunc implementation in bytes. It - // does *not* account for additional memory used during accumulation. - Size() int64 -} diff --git a/postgres/parser/sem/tree/as_of.go b/postgres/parser/sem/tree/as_of.go deleted file mode 100644 index 371c6b9555..0000000000 --- a/postgres/parser/sem/tree/as_of.go +++ /dev/null @@ -1,191 +0,0 @@ -// Copyright 2023 Dolthub, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Copyright 2018 The Cockroach Authors. -// -// Use of this software is governed by the Business Source License -// included in the file licenses/BSL.txt. -// -// As of the Change Date specified in that file, in accordance with -// the Business Source License, use of this software will be governed -// by the Apache License, Version 2.0, included in the file -// licenses/APL.txt. - -package tree - -import ( - "context" - "strconv" - "strings" - "time" - - "github.com/cockroachdb/apd/v2" - "github.com/cockroachdb/errors" - - "github.com/dolthub/doltgresql/postgres/parser/duration" - "github.com/dolthub/doltgresql/postgres/parser/hlc" - "github.com/dolthub/doltgresql/postgres/parser/pgcode" - "github.com/dolthub/doltgresql/postgres/parser/pgerror" - "github.com/dolthub/doltgresql/postgres/parser/types" -) - -// FollowerReadTimestampFunctionName is the name of the function which can be -// used with AOST clauses to generate a timestamp likely to be safe for follower -// reads. -const FollowerReadTimestampFunctionName = "follower_read_timestamp" - -// FollowerReadTimestampExperimentalFunctionName is the name of the old -// "experimental_" function, which we keep for backwards compatibility. -const FollowerReadTimestampExperimentalFunctionName = "experimental_follower_read_timestamp" - -var errInvalidExprForAsOf = errors.Errorf("AS OF SYSTEM TIME: only constant expressions or " + - FollowerReadTimestampFunctionName + " are allowed") - -// EvalAsOfTimestamp evaluates the timestamp argument to an AS OF SYSTEM TIME query. -func EvalAsOfTimestamp( - ctx context.Context, asOf AsOfClause, semaCtx *SemaContext, evalCtx *EvalContext, -) (tsss hlc.Timestamp, err error) { - // We need to save and restore the previous value of the field in - // semaCtx in case we are recursively called within a subquery - // context. - scalarProps := &semaCtx.Properties - defer scalarProps.Restore(*scalarProps) - scalarProps.Require("AS OF SYSTEM TIME", RejectSpecial|RejectSubqueries) - - // In order to support the follower reads feature we permit this expression - // to be a simple invocation of the `FollowerReadTimestampFunction`. - // Over time we could expand the set of allowed functions or expressions. - // All non-function expressions must be const and must TypeCheck into a - // string. - var te TypedExpr - if fe, ok := asOf.Expr.(*FuncExpr); ok { - def, err := fe.Func.Resolve(semaCtx.SearchPath) - if err != nil { - return hlc.Timestamp{}, errInvalidExprForAsOf - } - if def.Name != FollowerReadTimestampFunctionName && - def.Name != FollowerReadTimestampExperimentalFunctionName { - return hlc.Timestamp{}, errInvalidExprForAsOf - } - if te, err = fe.TypeCheck(ctx, semaCtx, types.TimestampTZ); err != nil { - return hlc.Timestamp{}, err - } - } else { - var err error - te, err = asOf.Expr.TypeCheck(ctx, semaCtx, types.String) - if err != nil { - return hlc.Timestamp{}, err - } - if !IsConst(evalCtx, te) { - return hlc.Timestamp{}, errInvalidExprForAsOf - } - } - - d, err := te.Eval(evalCtx) - if err != nil { - return hlc.Timestamp{}, err - } - - stmtTimestamp := evalCtx.GetStmtTimestamp() - ts, err := DatumToHLC(evalCtx, stmtTimestamp, d) - return ts, errors.Wrap(err, "AS OF SYSTEM TIME") -} - -// DatumToHLC performs the conversion from a Datum to an HLC timestamp. -func DatumToHLC(evalCtx *EvalContext, stmtTimestamp time.Time, d Datum) (hlc.Timestamp, error) { - ts := hlc.Timestamp{} - var convErr error - switch d := d.(type) { - case *DString: - s := string(*d) - // Attempt to parse as timestamp. - if dt, _, err := ParseDTimestamp(evalCtx, s, time.Nanosecond); err == nil { - ts.WallTime = dt.Time.UnixNano() - break - } - // Attempt to parse as a decimal. - if dec, _, err := apd.NewFromString(s); err == nil { - ts, convErr = DecimalToHLC(dec) - break - } - // Attempt to parse as an interval. - if iv, err := ParseDInterval(s); err == nil { - if (iv.Duration == duration.Duration{}) { - convErr = errors.Errorf("interval value %v too small, absolute value must be >= %v", d, time.Microsecond) - } - ts.WallTime = duration.Add(stmtTimestamp, iv.Duration).UnixNano() - break - } - convErr = errors.Errorf("value is neither timestamp, decimal, nor interval") - case *DTimestamp: - ts.WallTime = d.UnixNano() - case *DTimestampTZ: - ts.WallTime = d.UnixNano() - case *DInt: - ts.WallTime = int64(*d) - case *DDecimal: - ts, convErr = DecimalToHLC(&d.Decimal) - case *DInterval: - ts.WallTime = duration.Add(stmtTimestamp, d.Duration).UnixNano() - default: - convErr = errors.WithSafeDetails( - errors.Errorf("expected timestamp, decimal, or interval, got %s", d.ResolvedType()), - "go type: %T", d) - } - if convErr != nil { - return ts, convErr - } - zero := hlc.Timestamp{} - if ts == zero { - return ts, errors.Errorf("zero timestamp is invalid") - } else if ts.Less(zero) { - return ts, errors.Errorf("timestamp before 1970-01-01T00:00:00Z is invalid") - } - return ts, nil -} - -// DecimalToHLC performs the conversion from an inputted DECIMAL datum for an -// AS OF SYSTEM TIME query to an HLC timestamp. -func DecimalToHLC(d *apd.Decimal) (hlc.Timestamp, error) { - // Format the decimal into a string and split on `.` to extract the nanosecond - // walltime and logical tick parts. - // TODO(mjibson): use d.Modf() instead of converting to a string. - s := d.Text('f') - parts := strings.SplitN(s, ".", 2) - nanos, err := strconv.ParseInt(parts[0], 10, 64) - if err != nil { - return hlc.Timestamp{}, pgerror.Wrapf(err, pgcode.Syntax, "parsing argument") - } - var logical int64 - if len(parts) > 1 { - // logicalLength is the number of decimal digits expected in the - // logical part to the right of the decimal. See the implementation of - // cluster_logical_timestamp(). - const logicalLength = 10 - p := parts[1] - if lp := len(p); lp > logicalLength { - return hlc.Timestamp{}, pgerror.Newf(pgcode.Syntax, "logical part has too many digits") - } else if lp < logicalLength { - p += strings.Repeat("0", logicalLength-lp) - } - logical, err = strconv.ParseInt(p, 10, 32) - if err != nil { - return hlc.Timestamp{}, pgerror.Wrapf(err, pgcode.Syntax, "parsing argument") - } - } - return hlc.Timestamp{ - WallTime: nanos, - Logical: int32(logical), - }, nil -} diff --git a/postgres/parser/sem/tree/casts.go b/postgres/parser/sem/tree/casts.go index ce5213bbf2..dbda06bb30 100644 --- a/postgres/parser/sem/tree/casts.go +++ b/postgres/parser/sem/tree/casts.go @@ -25,24 +25,7 @@ package tree import ( - "math" - "math/big" - "strconv" - "strings" - "time" - - "github.com/lib/pq/oid" - - "github.com/dolthub/doltgresql/postgres/parser/duration" - "github.com/dolthub/doltgresql/postgres/parser/geo" - "github.com/dolthub/doltgresql/postgres/parser/geo/geopb" - "github.com/dolthub/doltgresql/postgres/parser/lex" - "github.com/dolthub/doltgresql/postgres/parser/pgcode" - "github.com/dolthub/doltgresql/postgres/parser/pgdate" - "github.com/dolthub/doltgresql/postgres/parser/pgerror" - "github.com/dolthub/doltgresql/postgres/parser/timeofday" "github.com/dolthub/doltgresql/postgres/parser/types" - "github.com/dolthub/doltgresql/postgres/parser/utils" ) type castInfo struct { @@ -337,848 +320,3 @@ func init() { func lookupCast(from, to types.Family) *castInfo { return castsMap[castsMapKey{from: from, to: to}] } - -// LookupCastVolatility returns the volatility of a valid cast. -func LookupCastVolatility(from, to *types.T) (_ Volatility, ok bool) { - fromFamily := from.Family() - toFamily := to.Family() - // Special case for casting between arrays. - if fromFamily == types.ArrayFamily && toFamily == types.ArrayFamily { - return LookupCastVolatility(from.ArrayContents(), to.ArrayContents()) - } - // Special case for casting between tuples. - if fromFamily == types.TupleFamily && toFamily == types.TupleFamily { - fromTypes := from.TupleContents() - toTypes := to.TupleContents() - if len(fromTypes) != len(toTypes) { - return 0, false - } - maxVolatility := VolatilityLeakProof - for i := range fromTypes { - v, ok := LookupCastVolatility(fromTypes[i], toTypes[i]) - if !ok { - return 0, false - } - if v > maxVolatility { - maxVolatility = v - } - } - return maxVolatility, true - } - cast := lookupCast(fromFamily, toFamily) - if cast == nil { - return 0, false - } - return cast.volatility, true -} - -// PerformCast performs a cast from the provided Datum to the specified -// types.T. -func PerformCast(ctx *EvalContext, d Datum, t *types.T) (Datum, error) { - switch t.Family() { - case types.BitFamily: - switch v := d.(type) { - case *DBitArray: - if t.Width() == 0 || v.BitLen() == uint(t.Width()) { - return d, nil - } - var a DBitArray - switch t.Oid() { - case oid.T_varbit: - // VARBITs do not have padding attached. - a.BitArray = v.BitArray.Clone() - if uint(t.Width()) < a.BitArray.BitLen() { - a.BitArray = a.BitArray.ToWidth(uint(t.Width())) - } - default: - a.BitArray = v.BitArray.Clone().ToWidth(uint(t.Width())) - } - return &a, nil - case *DInt: - return NewDBitArrayFromInt(int64(*v), uint(t.Width())) - case *DString: - res, err := utils.Parse(string(*v)) - if err != nil { - return nil, err - } - if t.Width() > 0 { - res = res.ToWidth(uint(t.Width())) - } - return &DBitArray{BitArray: res}, nil - case *DCollatedString: - res, err := utils.Parse(v.Contents) - if err != nil { - return nil, err - } - if t.Width() > 0 { - res = res.ToWidth(uint(t.Width())) - } - return &DBitArray{BitArray: res}, nil - } - - case types.BoolFamily: - switch v := d.(type) { - case *DBool: - return d, nil - case *DInt: - return MakeDBool(*v != 0), nil - case *DFloat: - return MakeDBool(*v != 0), nil - case *DDecimal: - return MakeDBool(v.Sign() != 0), nil - case *DString: - return ParseDBool(string(*v)) - case *DCollatedString: - return ParseDBool(v.Contents) - } - - case types.IntFamily: - var res *DInt - switch v := d.(type) { - case *DBitArray: - res = v.AsDInt(uint(t.Width())) - case *DBool: - if *v { - res = NewDInt(1) - } else { - res = DZero - } - case *DInt: - // TODO(knz): enforce the coltype width here. - res = v - case *DFloat: - f := float64(*v) - // Use `<=` and `>=` here instead of just `<` and `>` because when - // math.MaxInt64 and math.MinInt64 are converted to float64s, they are - // rounded to numbers with larger absolute values. Note that the first - // next FP value after and strictly greater than float64(math.MinInt64) - // is -9223372036854774784 (= float64(math.MinInt64)+513) and the first - // previous value and strictly smaller than float64(math.MaxInt64) - // is 9223372036854774784 (= float64(math.MaxInt64)-513), and both are - // convertible to int without overflow. - if math.IsNaN(f) || f <= float64(math.MinInt64) || f >= float64(math.MaxInt64) { - return nil, ErrIntOutOfRange - } - res = NewDInt(DInt(f)) - case *DDecimal: - d := ctx.getTmpDec() - _, err := DecimalCtx.RoundToIntegralValue(d, &v.Decimal) - if err != nil { - return nil, err - } - i, err := d.Int64() - if err != nil { - return nil, ErrIntOutOfRange - } - res = NewDInt(DInt(i)) - case *DString: - var err error - if res, err = ParseDInt(string(*v)); err != nil { - return nil, err - } - case *DCollatedString: - var err error - if res, err = ParseDInt(v.Contents); err != nil { - return nil, err - } - case *DTimestamp: - res = NewDInt(DInt(v.Unix())) - case *DTimestampTZ: - res = NewDInt(DInt(v.Unix())) - case *DDate: - // TODO(mjibson): This cast is unsupported by postgres. Should we remove ours? - if !v.IsFinite() { - return nil, ErrIntOutOfRange - } - res = NewDInt(DInt(v.UnixEpochDays())) - case *DInterval: - iv, ok := v.AsInt64() - if !ok { - return nil, ErrIntOutOfRange - } - res = NewDInt(DInt(iv)) - case *DOid: - res = &v.DInt - } - if res != nil { - return res, nil - } - - case types.EnumFamily: - switch v := d.(type) { - case *DString: - return MakeDEnumFromLogicalRepresentation(t, string(*v)) - case *DBytes: - return MakeDEnumFromPhysicalRepresentation(t, []byte(*v)) - case *DEnum: - return d, nil - } - - case types.FloatFamily: - switch v := d.(type) { - case *DBool: - if *v { - return NewDFloat(1), nil - } - return NewDFloat(0), nil - case *DInt: - return NewDFloat(DFloat(*v)), nil - case *DFloat: - return d, nil - case *DDecimal: - f, err := v.Float64() - if err != nil { - return nil, ErrFloatOutOfRange - } - return NewDFloat(DFloat(f)), nil - case *DString: - return ParseDFloat(string(*v)) - case *DCollatedString: - return ParseDFloat(v.Contents) - case *DTimestamp: - micros := float64(v.Nanosecond() / int(time.Microsecond)) - return NewDFloat(DFloat(float64(v.Unix()) + micros*1e-6)), nil - case *DTimestampTZ: - micros := float64(v.Nanosecond() / int(time.Microsecond)) - return NewDFloat(DFloat(float64(v.Unix()) + micros*1e-6)), nil - case *DDate: - // TODO(mjibson): This cast is unsupported by postgres. Should we remove ours? - if !v.IsFinite() { - return nil, ErrFloatOutOfRange - } - return NewDFloat(DFloat(float64(v.UnixEpochDays()))), nil - case *DInterval: - return NewDFloat(DFloat(v.AsFloat64())), nil - } - - case types.DecimalFamily: - var dd DDecimal - var err error - unset := false - switch v := d.(type) { - case *DBool: - if *v { - dd.SetInt64(1) - } - case *DInt: - dd.SetInt64(int64(*v)) - case *DDate: - // TODO(mjibson): This cast is unsupported by postgres. Should we remove ours? - if !v.IsFinite() { - return nil, errDecOutOfRange - } - dd.SetInt64(v.UnixEpochDays()) - case *DFloat: - _, err = dd.SetFloat64(float64(*v)) - case *DDecimal: - // Small optimization to avoid copying into dd in normal case. - if t.Precision() == 0 { - return d, nil - } - dd.Set(&v.Decimal) - case *DString: - err = dd.SetString(string(*v)) - case *DCollatedString: - err = dd.SetString(v.Contents) - case *DTimestamp: - val := &dd.Coeff - val.SetInt64(v.Unix()) - val.Mul(val, big10E6) - micros := v.Nanosecond() / int(time.Microsecond) - val.Add(val, big.NewInt(int64(micros))) - dd.Exponent = -6 - case *DTimestampTZ: - val := &dd.Coeff - val.SetInt64(v.Unix()) - val.Mul(val, big10E6) - micros := v.Nanosecond() / int(time.Microsecond) - val.Add(val, big.NewInt(int64(micros))) - dd.Exponent = -6 - case *DInterval: - v.AsBigInt(&dd.Coeff) - dd.Exponent = -9 - default: - unset = true - } - if err != nil { - return nil, err - } - if !unset { - // dd.Coeff must be positive. If it was set to a negative value - // above, transfer the sign to dd.Negative. - if dd.Coeff.Sign() < 0 { - dd.Negative = true - dd.Coeff.Abs(&dd.Coeff) - } - err = LimitDecimalWidth(&dd.Decimal, int(t.Precision()), int(t.Scale())) - return &dd, err - } - - case types.StringFamily, types.CollatedStringFamily: - var s string - switch t := d.(type) { - case *DBitArray: - s = t.BitArray.String() - case *DFloat: - s = strconv.FormatFloat(float64(*t), 'g', - ctx.SessionData.DataConversion.GetFloatPrec(), 64) - case *DBool, *DInt, *DDecimal: - s = d.String() - case *DTimestamp, *DDate, *DTime, *DTimeTZ, *DGeography, *DGeometry, *DBox2D: - s = AsStringWithFlags(d, FmtBareStrings) - case *DTimestampTZ: - // Convert to context timezone for correct display. - ts, err := MakeDTimestampTZ(t.In(ctx.GetLocation()), time.Microsecond) - if err != nil { - return nil, err - } - s = AsStringWithFlags( - ts, - FmtBareStrings, - ) - case *DTuple: - s = AsStringWithFlags(d, FmtPgwireText) - case *DArray: - s = AsStringWithFlags(d, FmtPgwireText) - case *DInterval: - // When converting an interval to string, we need a string representation - // of the duration (e.g. "5s") and not of the interval itself (e.g. - // "INTERVAL '5s'"). - s = t.ValueAsString() - case *DUuid: - s = t.UUID.String() - case *DIPAddr: - s = AsStringWithFlags(d, FmtBareStrings) - case *DString: - s = string(*t) - case *DCollatedString: - s = t.Contents - case *DBytes: - s = lex.EncodeByteArrayToRawBytes(string(*t), - ctx.SessionData.DataConversion.BytesEncodeFormat, false /* skipHexPrefix */) - case *DOid: - s = t.String() - case *DJSON: - s = t.JSON.String() - case *DEnum: - s = t.LogicalRep - } - switch t.Family() { - case types.StringFamily: - if t.Oid() == oid.T_name { - return NewDName(s), nil - } - - // bpchar types truncate trailing whitespace. - if t.Oid() == oid.T_bpchar { - s = strings.TrimRight(s, " ") - } - - // If the string type specifies a limit we truncate to that limit: - // 'hello'::CHAR(2) -> 'he' - // This is true of all the string type variants. - if t.Width() > 0 { - s = truncateString(s, int(t.Width())) - } - return NewDString(s), nil - case types.CollatedStringFamily: - // bpchar types truncate trailing whitespace. - if t.Oid() == oid.T_bpchar { - s = strings.TrimRight(s, " ") - } - // Ditto truncation like for TString. - if t.Width() > 0 { - s = truncateString(s, int(t.Width())) - } - return NewDCollatedString(s, t.Locale(), &ctx.CollationEnv) - } - - case types.BytesFamily: - switch t := d.(type) { - case *DString: - return ParseDByte(string(*t)) - case *DCollatedString: - return NewDBytes(DBytes(t.Contents)), nil - case *DUuid: - return NewDBytes(DBytes(t.GetBytes())), nil - case *DBytes: - return d, nil - case *DGeography: - return NewDBytes(DBytes(t.Geography.EWKB())), nil - case *DGeometry: - return NewDBytes(DBytes(t.Geometry.EWKB())), nil - } - - case types.UuidFamily: - switch t := d.(type) { - case *DString: - return ParseDUuidFromString(string(*t)) - case *DCollatedString: - return ParseDUuidFromString(t.Contents) - case *DBytes: - return ParseDUuidFromBytes([]byte(*t)) - case *DUuid: - return d, nil - } - - case types.INetFamily: - switch t := d.(type) { - case *DString: - return ParseDIPAddrFromINetString(string(*t)) - case *DCollatedString: - return ParseDIPAddrFromINetString(t.Contents) - case *DIPAddr: - return d, nil - } - - case types.Box2DFamily: - switch d := d.(type) { - case *DString: - return ParseDBox2D(string(*d)) - case *DCollatedString: - return ParseDBox2D(d.Contents) - case *DBox2D: - return d, nil - case *DGeometry: - bbox := d.CartesianBoundingBox() - if bbox == nil { - return DNull, nil - } - return NewDBox2D(*bbox), nil - } - - case types.GeographyFamily: - switch d := d.(type) { - case *DString: - return ParseDGeography(string(*d)) - case *DCollatedString: - return ParseDGeography(d.Contents) - case *DGeography: - if err := geo.SpatialObjectFitsColumnMetadata( - d.Geography.SpatialObject(), - t.InternalType.GeoMetadata.SRID, - t.InternalType.GeoMetadata.ShapeType, - ); err != nil { - return nil, err - } - return d, nil - case *DGeometry: - g, err := d.AsGeography() - if err != nil { - return nil, err - } - if err := geo.SpatialObjectFitsColumnMetadata( - g.SpatialObject(), - t.InternalType.GeoMetadata.SRID, - t.InternalType.GeoMetadata.ShapeType, - ); err != nil { - return nil, err - } - return &DGeography{g}, nil - case *DJSON: - t, err := d.AsText() - if err != nil { - return nil, err - } - g, err := geo.ParseGeographyFromGeoJSON([]byte(*t)) - if err != nil { - return nil, err - } - return &DGeography{g}, nil - case *DBytes: - g, err := geo.ParseGeographyFromEWKB(geopb.EWKB(*d)) - if err != nil { - return nil, err - } - return &DGeography{g}, nil - } - case types.GeometryFamily: - switch d := d.(type) { - case *DString: - return ParseDGeometry(string(*d)) - case *DCollatedString: - return ParseDGeometry(d.Contents) - case *DGeometry: - if err := geo.SpatialObjectFitsColumnMetadata( - d.Geometry.SpatialObject(), - t.InternalType.GeoMetadata.SRID, - t.InternalType.GeoMetadata.ShapeType, - ); err != nil { - return nil, err - } - return d, nil - case *DGeography: - if err := geo.SpatialObjectFitsColumnMetadata( - d.Geography.SpatialObject(), - t.InternalType.GeoMetadata.SRID, - t.InternalType.GeoMetadata.ShapeType, - ); err != nil { - return nil, err - } - g, err := d.AsGeometry() - if err != nil { - return nil, err - } - return &DGeometry{g}, nil - case *DJSON: - t, err := d.AsText() - if err != nil { - return nil, err - } - g, err := geo.ParseGeometryFromGeoJSON([]byte(*t)) - if err != nil { - return nil, err - } - return &DGeometry{g}, nil - case *DBox2D: - g, err := geo.MakeGeometryFromGeomT(d.ToGeomT(geopb.DefaultGeometrySRID)) - if err != nil { - return nil, err - } - return &DGeometry{g}, nil - case *DBytes: - g, err := geo.ParseGeometryFromEWKB(geopb.EWKB(*d)) - if err != nil { - return nil, err - } - return &DGeometry{g}, nil - } - - case types.DateFamily: - switch d := d.(type) { - case *DString: - res, _, err := ParseDDate(ctx, string(*d)) - return res, err - case *DCollatedString: - res, _, err := ParseDDate(ctx, d.Contents) - return res, err - case *DDate: - return d, nil - case *DInt: - // TODO(mjibson): This cast is unsupported by postgres. Should we remove ours? - t, err := pgdate.MakeDateFromUnixEpoch(int64(*d)) - return NewDDate(t), err - case *DTimestampTZ: - return NewDDateFromTime(d.Time.In(ctx.GetLocation())) - case *DTimestamp: - return NewDDateFromTime(d.Time) - } - - case types.TimeFamily: - roundTo := TimeFamilyPrecisionToRoundDuration(t.Precision()) - switch d := d.(type) { - case *DString: - res, _, err := ParseDTime(ctx, string(*d), roundTo) - return res, err - case *DCollatedString: - res, _, err := ParseDTime(ctx, d.Contents, roundTo) - return res, err - case *DTime: - return d.Round(roundTo), nil - case *DTimeTZ: - return MakeDTime(d.TimeOfDay.Round(roundTo)), nil - case *DTimestamp: - return MakeDTime(timeofday.FromTime(d.Time).Round(roundTo)), nil - case *DTimestampTZ: - // Strip time zone. Times don't carry their location. - stripped, err := d.stripTimeZone(ctx) - if err != nil { - return nil, err - } - return MakeDTime(timeofday.FromTime(stripped.Time).Round(roundTo)), nil - case *DInterval: - return MakeDTime(timeofday.Min.Add(d.Duration).Round(roundTo)), nil - } - - case types.TimeTZFamily: - roundTo := TimeFamilyPrecisionToRoundDuration(t.Precision()) - switch d := d.(type) { - case *DString: - res, _, err := ParseDTimeTZ(ctx, string(*d), roundTo) - return res, err - case *DCollatedString: - res, _, err := ParseDTimeTZ(ctx, d.Contents, roundTo) - return res, err - case *DTime: - return NewDTimeTZFromLocation(timeofday.TimeOfDay(*d).Round(roundTo), ctx.GetLocation()), nil - case *DTimeTZ: - return d.Round(roundTo), nil - case *DTimestampTZ: - return NewDTimeTZFromTime(d.Time.In(ctx.GetLocation()).Round(roundTo)), nil - } - - case types.TimestampFamily: - roundTo := TimeFamilyPrecisionToRoundDuration(t.Precision()) - // TODO(knz): Timestamp from float, decimal. - switch d := d.(type) { - case *DString: - res, _, err := ParseDTimestamp(ctx, string(*d), roundTo) - return res, err - case *DCollatedString: - res, _, err := ParseDTimestamp(ctx, d.Contents, roundTo) - return res, err - case *DDate: - t, err := d.ToTime() - if err != nil { - return nil, err - } - return MakeDTimestamp(t, roundTo) - case *DInt: - return MakeDTimestamp(time.Unix(int64(*d), 0).UTC(), roundTo) - case *DTimestamp: - return d.Round(roundTo) - case *DTimestampTZ: - // Strip time zone. Timestamps don't carry their location. - stripped, err := d.stripTimeZone(ctx) - if err != nil { - return nil, err - } - return stripped.Round(roundTo) - } - - case types.TimestampTZFamily: - roundTo := TimeFamilyPrecisionToRoundDuration(t.Precision()) - // TODO(knz): TimestampTZ from float, decimal. - switch d := d.(type) { - case *DString: - res, _, err := ParseDTimestampTZ(ctx, string(*d), roundTo) - return res, err - case *DCollatedString: - res, _, err := ParseDTimestampTZ(ctx, d.Contents, roundTo) - return res, err - case *DDate: - t, err := d.ToTime() - if err != nil { - return nil, err - } - _, before := t.Zone() - _, after := t.In(ctx.GetLocation()).Zone() - return MakeDTimestampTZ(t.Add(time.Duration(before-after)*time.Second), roundTo) - case *DTimestamp: - _, before := d.Time.Zone() - _, after := d.Time.In(ctx.GetLocation()).Zone() - return MakeDTimestampTZ(d.Time.Add(time.Duration(before-after)*time.Second), roundTo) - case *DInt: - return MakeDTimestampTZ(time.Unix(int64(*d), 0).UTC(), roundTo) - case *DTimestampTZ: - return d.Round(roundTo) - } - - case types.IntervalFamily: - itm, err := t.IntervalTypeMetadata() - if err != nil { - return nil, err - } - switch v := d.(type) { - case *DString: - return ParseDIntervalWithTypeMetadata(string(*v), itm) - case *DCollatedString: - return ParseDIntervalWithTypeMetadata(v.Contents, itm) - case *DInt: - return NewDInterval(duration.FromInt64(int64(*v)), itm), nil - case *DFloat: - return NewDInterval(duration.FromFloat64(float64(*v)), itm), nil - case *DTime: - return NewDInterval(duration.MakeDuration(int64(*v)*1000, 0, 0), itm), nil - case *DDecimal: - d := ctx.getTmpDec() - dnanos := v.Decimal - dnanos.Exponent += 9 - // We need HighPrecisionCtx because duration values can contain - // upward of 35 decimal digits and DecimalCtx only provides 25. - _, err := HighPrecisionCtx.Quantize(d, &dnanos, 0) - if err != nil { - return nil, err - } - if dnanos.Negative { - d.Coeff.Neg(&d.Coeff) - } - dv, ok := duration.FromBigInt(&d.Coeff) - if !ok { - return nil, errDecOutOfRange - } - return NewDInterval(dv, itm), nil - case *DInterval: - return NewDInterval(v.Duration, itm), nil - } - case types.JsonFamily: - switch v := d.(type) { - case *DString: - return ParseDJSON(string(*v)) - case *DJSON: - return v, nil - case *DGeography: - j, err := geo.SpatialObjectToGeoJSON(v.Geography.SpatialObject(), -1, geo.SpatialObjectToGeoJSONFlagZero) - if err != nil { - return nil, err - } - return ParseDJSON(string(j)) - case *DGeometry: - j, err := geo.SpatialObjectToGeoJSON(v.Geometry.SpatialObject(), -1, geo.SpatialObjectToGeoJSONFlagZero) - if err != nil { - return nil, err - } - return ParseDJSON(string(j)) - } - case types.ArrayFamily: - switch v := d.(type) { - case *DString: - res, _, err := ParseDArrayFromString(ctx, string(*v), t.ArrayContents()) - return res, err - case *DArray: - dcast := NewDArray(t.ArrayContents()) - for _, e := range v.Array { - ecast := DNull - if e != DNull { - var err error - ecast, err = PerformCast(ctx, e, t.ArrayContents()) - if err != nil { - return nil, err - } - } - - if err := dcast.Append(ecast); err != nil { - return nil, err - } - } - return dcast, nil - } - case types.OidFamily: - switch v := d.(type) { - case *DOid: - switch t.Oid() { - case oid.T_oid: - return &DOid{semanticType: t, DInt: v.DInt}, nil - case oid.T_regtype: - // Mapping an oid to a regtype is easy: we have a hardcoded map. - typ, ok := types.OidToType[oid.Oid(v.DInt)] - ret := &DOid{semanticType: t, DInt: v.DInt} - if !ok { - return ret, nil - } - ret.name = typ.PGName() - return ret, nil - default: - oid, err := queryOid(ctx, t, v) - if err != nil { - oid = NewDOid(v.DInt) - oid.semanticType = t - } - return oid, nil - } - case *DInt: - switch t.Oid() { - case oid.T_oid: - return &DOid{semanticType: t, DInt: *v}, nil - default: - tmpOid := NewDOid(*v) - oid, err := queryOid(ctx, t, tmpOid) - if err != nil { - oid = tmpOid - oid.semanticType = t - } - return oid, nil - } - case *DString: - s := string(*v) - // Trim whitespace and unwrap outer quotes if necessary. - // This is required to mimic postgres. - s = strings.TrimSpace(s) - origS := s - if len(s) > 1 && s[0] == '"' && s[len(s)-1] == '"' { - s = s[1 : len(s)-1] - } - - switch t.Oid() { - case oid.T_oid: - i, err := ParseDInt(s) - if err != nil { - return nil, err - } - return &DOid{semanticType: t, DInt: *i}, nil - case oid.T_regproc, oid.T_regprocedure: - // Trim procedure type parameters, e.g. `max(int)` becomes `max`. - // Postgres only does this when the cast is ::regprocedure, but we're - // going to always do it. - // We additionally do not yet implement disambiguation based on type - // parameters: we return the match iff there is exactly one. - s = pgSignatureRegexp.ReplaceAllString(s, "$1") - // Resolve function name. - substrs := strings.Split(s, ".") - if len(substrs) > 3 { - // A fully qualified function name in pg's dialect can contain - // at most 3 parts: db.schema.funname. - // For example mydb.pg_catalog.max(). - // Anything longer is always invalid. - return nil, pgerror.Newf(pgcode.Syntax, - "invalid function name: %s", s) - } - name := UnresolvedName{NumParts: len(substrs)} - for i := 0; i < len(substrs); i++ { - name.Parts[i] = substrs[len(substrs)-1-i] - } - funcDef, err := name.ResolveFunction(ctx.SessionData.SearchPath) - if err != nil { - return nil, err - } - return queryOid(ctx, t, NewDString(funcDef.Name)) - case oid.T_regtype: - parsedTyp, err := ctx.Planner.ParseType(s) - if err == nil { - return &DOid{ - semanticType: t, - DInt: DInt(parsedTyp.Oid()), - name: parsedTyp.SQLStandardName(), - }, nil - } - // Fall back to searching pg_type, since we don't provide syntax for - // every postgres type that we understand OIDs for. - // Trim type modifiers, e.g. `numeric(10,3)` becomes `numeric`. - s = pgSignatureRegexp.ReplaceAllString(s, "$1") - dOid, missingTypeErr := queryOid(ctx, t, NewDString(s)) - if missingTypeErr == nil { - return dOid, missingTypeErr - } - // Fall back to some special cases that we support for compatibility - // only. Client use syntax like 'sometype'::regtype to produce the oid - // for a type that they want to search a catalog table for. Since we - // don't support that type, we return an artificial OID that will never - // match anything. - switch s { - // We don't support triggers, but some tools search for them - // specifically. - case "trigger": - default: - return nil, missingTypeErr - } - return &DOid{ - semanticType: t, - // Types we don't support get OID -1, so they won't match anything - // in catalogs. - DInt: -1, - name: s, - }, nil - - case oid.T_regclass: - tn, err := ctx.Planner.ParseQualifiedTableName(origS) - if err != nil { - return nil, err - } - id, err := ctx.Planner.ResolveTableName(ctx.Ctx(), tn) - if err != nil { - return nil, err - } - return &DOid{ - semanticType: t, - DInt: DInt(id), - name: tn.ObjectName.String(), - }, nil - default: - return queryOid(ctx, t, NewDString(s)) - } - } - } - - return nil, pgerror.Newf( - pgcode.CannotCoerce, "invalid cast: %s -> %s", d.ResolvedType(), t) -} diff --git a/postgres/parser/sem/tree/constant_eval.go b/postgres/parser/sem/tree/constant_eval.go deleted file mode 100644 index c703976475..0000000000 --- a/postgres/parser/sem/tree/constant_eval.go +++ /dev/null @@ -1,82 +0,0 @@ -// Copyright 2023 Dolthub, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Copyright 2018 The Cockroach Authors. -// -// Use of this software is governed by the Business Source License -// included in the file licenses/BSL.txt. -// -// As of the Change Date specified in that file, in accordance with -// the Business Source License, use of this software will be governed -// by the Apache License, Version 2.0, included in the file -// licenses/APL.txt. - -package tree - -// ConstantEvalVisitor replaces constant TypedExprs with the result of Eval. -type ConstantEvalVisitor struct { - ctx *EvalContext - err error - - fastIsConstVisitor fastIsConstVisitor -} - -var _ Visitor = &ConstantEvalVisitor{} - -// MakeConstantEvalVisitor creates a ConstantEvalVisitor instance. -func MakeConstantEvalVisitor(ctx *EvalContext) ConstantEvalVisitor { - return ConstantEvalVisitor{ctx: ctx, fastIsConstVisitor: fastIsConstVisitor{ctx: ctx}} -} - -// Err retrieves the error field in the ConstantEvalVisitor. -func (v *ConstantEvalVisitor) Err() error { return v.err } - -// VisitPre implements the Visitor interface. -func (v *ConstantEvalVisitor) VisitPre(expr Expr) (recurse bool, newExpr Expr) { - if v.err != nil { - return false, expr - } - return true, expr -} - -// VisitPost implements the Visitor interface. -func (v *ConstantEvalVisitor) VisitPost(expr Expr) Expr { - if v.err != nil { - return expr - } - - typedExpr, ok := expr.(TypedExpr) - if !ok || !v.isConst(expr) { - return expr - } - - value, err := typedExpr.Eval(v.ctx) - if err != nil { - // Ignore any errors here (e.g. division by zero), so they can happen - // during execution where they are correctly handled. Note that in some - // cases we might not even get an error (if this particular expression - // does not get evaluated when the query runs, e.g. it's inside a CASE). - return expr - } - if value == DNull { - // We don't want to return an expression that has a different type; cast - // the NULL if necessary. - return ReType(DNull, typedExpr.ResolvedType()) - } - return value -} - -func (v *ConstantEvalVisitor) isConst(expr Expr) bool { - return v.fastIsConstVisitor.run(expr) -} diff --git a/postgres/parser/sem/tree/constants.go b/postgres/parser/sem/tree/constants.go deleted file mode 100644 index c7058ce1d8..0000000000 --- a/postgres/parser/sem/tree/constants.go +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright 2023 Dolthub, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Copyright 2020 The Cockroach Authors. -// -// Use of this software is governed by the Business Source License -// included in the file licenses/BSL.txt. -// -// As of the Change Date specified in that file, in accordance with -// the Business Source License, use of this software will be governed -// by the Apache License, Version 2.0, included in the file -// licenses/APL.txt. - -package tree - -// NoColumnIdx is a special value that can be used as a "column index" to -// indicate that the column is not present. -const NoColumnIdx = -1 diff --git a/postgres/parser/sem/tree/datum.go b/postgres/parser/sem/tree/datum.go index 1d5cd9a212..5f40be8061 100644 --- a/postgres/parser/sem/tree/datum.go +++ b/postgres/parser/sem/tree/datum.go @@ -25,13 +25,9 @@ package tree import ( - "bytes" - "fmt" "math" "math/big" - "net" "regexp" - "sort" "strconv" "strings" "time" @@ -52,7 +48,6 @@ import ( "github.com/dolthub/doltgresql/postgres/parser/pgcode" "github.com/dolthub/doltgresql/postgres/parser/pgdate" "github.com/dolthub/doltgresql/postgres/parser/pgerror" - "github.com/dolthub/doltgresql/postgres/parser/roachpb" "github.com/dolthub/doltgresql/postgres/parser/stringencoding" "github.com/dolthub/doltgresql/postgres/parser/timeofday" "github.com/dolthub/doltgresql/postgres/parser/timetz" @@ -75,9 +70,6 @@ var ( // DNull is the NULL Datum. DNull Datum = dNull{} - // DZero is the zero-valued integer Datum. - DZero = NewDInt(0) - // DTimeMaxTimeRegex is a compiled regex for parsing the 24:00 time value. DTimeMaxTimeRegex = regexp.MustCompile(`^([0-9-]*(\s|T))?\s*24:00(:00(.0+)?)?\s*$`) @@ -108,58 +100,6 @@ type Datum interface { // fmtFlags.disambiguateDatumTypes. AmbiguousFormat() bool - // Compare returns -1 if the receiver is less than other, 0 if receiver is - // equal to other and +1 if receiver is greater than other. - Compare(ctx *EvalContext, other Datum) int - - // Prev returns the previous datum and true, if one exists, or nil and false. - // The previous datum satisfies the following definition: if the receiver is - // "b" and the returned datum is "a", then for every compatible datum "x", it - // holds that "x < b" is true if and only if "x <= a" is true. - // - // The return value is undefined if IsMin(_ *EvalContext) returns true. - // - // TODO(#12022): for DTuple, the contract is actually that "x < b" (SQL order, - // where NULL < x is unknown for all x) is true only if "x <= a" - // (.Compare/encoding order, where NULL <= x is true for all x) is true. This - // is okay for now: the returned datum is used only to construct a span, which - // uses .Compare/encoding order and is guaranteed to be large enough by this - // weaker contract. The original filter expression is left in place to catch - // false positives. - Prev(ctx *EvalContext) (Datum, bool) - - // IsMin returns true if the datum is equal to the minimum value the datum - // type can hold. - IsMin(ctx *EvalContext) bool - - // Next returns the next datum and true, if one exists, or nil and false - // otherwise. The next datum satisfies the following definition: if the - // receiver is "a" and the returned datum is "b", then for every compatible - // datum "x", it holds that "x > a" is true if and only if "x >= b" is true. - // - // The return value is undefined if IsMax(_ *EvalContext) returns true. - // - // TODO(#12022): for DTuple, the contract is actually that "x > a" (SQL order, - // where x > NULL is unknown for all x) is true only if "x >= b" - // (.Compare/encoding order, where x >= NULL is true for all x) is true. This - // is okay for now: the returned datum is used only to construct a span, which - // uses .Compare/encoding order and is guaranteed to be large enough by this - // weaker contract. The original filter expression is left in place to catch - // false positives. - Next(ctx *EvalContext) (Datum, bool) - - // IsMax returns true if the datum is equal to the maximum value the datum - // type can hold. - IsMax(ctx *EvalContext) bool - - // Max returns the upper value and true, if one exists, otherwise - // nil and false. Used By Prev(). - Max(ctx *EvalContext) (Datum, bool) - - // Min returns the lower value, if one exists, otherwise nil and - // false. Used by Next(). - Min(ctx *EvalContext) (Datum, bool) - // Size returns a lower bound on the total size of the receiver in bytes, // including memory that is pointed at (even if shared between Datum // instances) but excluding allocation overhead. @@ -171,13 +111,6 @@ type Datum interface { // Datums is a slice of Datum values. type Datums []Datum -const ( - // SizeOfDatum is the memory size of a Datum reference. - SizeOfDatum = int64(unsafe.Sizeof(Datum(nil))) - // SizeOfDatums is the memory size of a Datum slice. - SizeOfDatums = int64(unsafe.Sizeof(Datums(nil))) -) - // Len returns the number of Datum values. func (d Datums) Len() int { return len(d) } @@ -193,52 +126,6 @@ func (d *Datums) Format(ctx *FmtCtx) { ctx.WriteByte(')') } -// Compare does a lexicographical comparison and returns -1 if the receiver -// is less than other, 0 if receiver is equal to other and +1 if receiver is -// greater than other. -func (d Datums) Compare(evalCtx *EvalContext, other Datums) int { - if len(d) == 0 { - panic(errors.AssertionFailedf("empty Datums being compared to other")) - } - - for i := range d { - if i >= len(other) { - return 1 - } - - compareDatum := d[i].Compare(evalCtx, other[i]) - if compareDatum != 0 { - return compareDatum - } - } - - if len(d) < len(other) { - return -1 - } - return 0 -} - -// IsDistinctFrom checks to see if two datums are distinct from each other. Any -// change in value is considered distinct, however, a NULL value is NOT -// considered distinct from another NULL value. -func (d Datums) IsDistinctFrom(evalCtx *EvalContext, other Datums) bool { - if len(d) != len(other) { - return true - } - for i, val := range d { - if val == DNull { - if other[i] != DNull { - return true - } - } else { - if val.Compare(evalCtx, other[i]) != 0 { - return true - } - } - } - return false -} - // CompositeDatum is a Datum that may require composite encoding in // indexes. Any Datum implementing this interface must also add itself to // colinfo.HasCompositeKeyEncoding. @@ -261,27 +148,6 @@ func MakeDBool(d DBool) *DBool { return DBoolFalse } -// MustBeDBool attempts to retrieve a DBool from an Expr, panicking if the -// assertion fails. -func MustBeDBool(e Expr) DBool { - b, ok := AsDBool(e) - if !ok { - panic(errors.AssertionFailedf("expected *DBool, found %T", e)) - } - return b -} - -// AsDBool attempts to retrieve a *DBool from an Expr, returning a *DBool and -// a flag signifying whether the assertion was successful. The function should -// be used instead of direct type assertions. -func AsDBool(e Expr) (DBool, bool) { - switch t := e.(type) { - case *DBool: - return *t, true - } - return false, false -} - // makeParseError returns a parse error using the provided string and type. An // optional error can be provided, which will be appended to the end of the // error string. @@ -294,11 +160,6 @@ func makeParseError(s string, typ *types.T, err error) error { "could not parse %q as type %s", s, typ) } -func makeUnsupportedComparisonMessage(d1, d2 Datum) error { - return errors.AssertionFailedWithDepthf(1, - "unsupported comparison: %s to %s", errors.Safe(d1.ResolvedType()), errors.Safe(d2.ResolvedType())) -} - func isCaseInsensitivePrefix(prefix, s string) bool { if len(prefix) > len(s) { return false @@ -396,76 +257,11 @@ func ParseDIPAddrFromINetString(s string) (*DIPAddr, error) { return &d, nil } -// GetBool gets DBool or an error (also treats NULL as false, not an error). -func GetBool(d Datum) (DBool, error) { - if v, ok := d.(*DBool); ok { - return *v, nil - } - if d == DNull { - return DBool(false), nil - } - return false, errors.AssertionFailedf("cannot convert %s to type %s", d.ResolvedType(), types.Bool) -} - // ResolvedType implements the TypedExpr interface. func (*DBool) ResolvedType() *types.T { return types.Bool } -// Compare implements the Datum interface. -func (d *DBool) Compare(ctx *EvalContext, other Datum) int { - if other == DNull { - // NULL is less than any non-NULL value. - return 1 - } - v, ok := UnwrapDatum(ctx, other).(*DBool) - if !ok { - panic(makeUnsupportedComparisonMessage(d, other)) - } - return CompareBools(bool(*d), bool(*v)) -} - -// CompareBools compares the input bools according to the SQL comparison rules. -func CompareBools(d, v bool) int { - if !d && v { - return -1 - } - if d && !v { - return 1 - } - return 0 -} - -// Prev implements the Datum interface. -func (*DBool) Prev(_ *EvalContext) (Datum, bool) { - return DBoolFalse, true -} - -// Next implements the Datum interface. -func (*DBool) Next(_ *EvalContext) (Datum, bool) { - return DBoolTrue, true -} - -// IsMax implements the Datum interface. -func (d *DBool) IsMax(_ *EvalContext) bool { - return bool(*d) -} - -// IsMin implements the Datum interface. -func (d *DBool) IsMin(_ *EvalContext) bool { - return !bool(*d) -} - -// Min implements the Datum interface. -func (d *DBool) Min(_ *EvalContext) (Datum, bool) { - return DBoolFalse, true -} - -// Max implements the Datum interface. -func (d *DBool) Max(_ *EvalContext) (Datum, bool) { - return DBoolTrue, true -} - // AmbiguousFormat implements the Datum interface. func (*DBool) AmbiguousFormat() bool { return false } @@ -503,38 +299,6 @@ func ParseDBitArray(s string) (*DBitArray, error) { return &a, nil } -// NewDBitArray returns a DBitArray. -func NewDBitArray(bitLen uint) *DBitArray { - a := MakeDBitArray(bitLen) - return &a -} - -// MakeDBitArray returns a DBitArray. -func MakeDBitArray(bitLen uint) DBitArray { - return DBitArray{BitArray: utils.MakeZeroBitArray(bitLen)} -} - -// MustBeDBitArray attempts to retrieve a DBitArray from an Expr, panicking if the -// assertion fails. -func MustBeDBitArray(e Expr) *DBitArray { - b, ok := AsDBitArray(e) - if !ok { - panic(errors.AssertionFailedf("expected *DBitArray, found %T", e)) - } - return b -} - -// AsDBitArray attempts to retrieve a *DBitArray from an Expr, returning a *DBitArray and -// a flag signifying whether the assertion was successful. The function should -// be used instead of direct type assertions. -func AsDBitArray(e Expr) (*DBitArray, bool) { - switch t := e.(type) { - case *DBitArray: - return t, true - } - return nil, false -} - var errCannotCastNegativeIntToBitArray = pgerror.Newf(pgcode.CannotCoerce, "cannot cast negative integer to bit varying with unbounded width") @@ -570,52 +334,6 @@ func (*DBitArray) ResolvedType() *types.T { return types.VarBit } -// Compare implements the Datum interface. -func (d *DBitArray) Compare(ctx *EvalContext, other Datum) int { - if other == DNull { - // NULL is less than any non-NULL value. - return 1 - } - v, ok := UnwrapDatum(ctx, other).(*DBitArray) - if !ok { - panic(makeUnsupportedComparisonMessage(d, other)) - } - return utils.Compare(d.BitArray, v.BitArray) -} - -// Prev implements the Datum interface. -func (d *DBitArray) Prev(_ *EvalContext) (Datum, bool) { - return nil, false -} - -// Next implements the Datum interface. -func (d *DBitArray) Next(_ *EvalContext) (Datum, bool) { - a := utils.Next(d.BitArray) - return &DBitArray{BitArray: a}, true -} - -// IsMax implements the Datum interface. -func (d *DBitArray) IsMax(_ *EvalContext) bool { - return false -} - -// IsMin implements the Datum interface. -func (d *DBitArray) IsMin(_ *EvalContext) bool { - return d.BitArray.IsEmpty() -} - -var bitArrayZero = NewDBitArray(0) - -// Min implements the Datum interface. -func (d *DBitArray) Min(_ *EvalContext) (Datum, bool) { - return bitArrayZero, true -} - -// Max implements the Datum interface. -func (d *DBitArray) Max(_ *EvalContext) (Datum, bool) { - return nil, false -} - // AmbiguousFormat implements the Datum interface. func (*DBitArray) AmbiguousFormat() bool { return false } @@ -659,94 +377,11 @@ func ParseDInt(s string) (*DInt, error) { return NewDInt(DInt(i)), nil } -// AsDInt attempts to retrieve a DInt from an Expr, returning a DInt and -// a flag signifying whether the assertion was successful. The function should -// be used instead of direct type assertions wherever a *DInt wrapped by a -// *DOidWrapper is possible. -func AsDInt(e Expr) (DInt, bool) { - switch t := e.(type) { - case *DInt: - return *t, true - case *DOidWrapper: - return AsDInt(t.Wrapped) - } - return 0, false -} - -// MustBeDInt attempts to retrieve a DInt from an Expr, panicking if the -// assertion fails. -func MustBeDInt(e Expr) DInt { - i, ok := AsDInt(e) - if !ok { - panic(errors.AssertionFailedf("expected *DInt, found %T", e)) - } - return i -} - // ResolvedType implements the TypedExpr interface. func (*DInt) ResolvedType() *types.T { return types.Int } -// Compare implements the Datum interface. -func (d *DInt) Compare(ctx *EvalContext, other Datum) int { - if other == DNull { - // NULL is less than any non-NULL value. - return 1 - } - var v DInt - switch t := UnwrapDatum(ctx, other).(type) { - case *DInt: - v = *t - case *DFloat, *DDecimal: - return -t.Compare(ctx, d) - case *DOid: - v = t.DInt - default: - panic(makeUnsupportedComparisonMessage(d, other)) - } - if *d < v { - return -1 - } - if *d > v { - return 1 - } - return 0 -} - -// Prev implements the Datum interface. -func (d *DInt) Prev(_ *EvalContext) (Datum, bool) { - return NewDInt(*d - 1), true -} - -// Next implements the Datum interface. -func (d *DInt) Next(_ *EvalContext) (Datum, bool) { - return NewDInt(*d + 1), true -} - -// IsMax implements the Datum interface. -func (d *DInt) IsMax(_ *EvalContext) bool { - return *d == math.MaxInt64 -} - -// IsMin implements the Datum interface. -func (d *DInt) IsMin(_ *EvalContext) bool { - return *d == math.MinInt64 -} - -var dMaxInt = NewDInt(math.MaxInt64) -var dMinInt = NewDInt(math.MinInt64) - -// Max implements the Datum interface. -func (d *DInt) Max(_ *EvalContext) (Datum, bool) { - return dMaxInt, true -} - -// Min implements the Datum interface. -func (d *DInt) Min(_ *EvalContext) (Datum, bool) { - return dMinInt, true -} - // AmbiguousFormat implements the Datum interface. func (*DInt) AmbiguousFormat() bool { return true } @@ -774,16 +409,6 @@ func (d *DInt) Size() uintptr { // DFloat is the float Datum. type DFloat float64 -// MustBeDFloat attempts to retrieve a DFloat from an Expr, panicking if the -// assertion fails. -func MustBeDFloat(e Expr) DFloat { - switch t := e.(type) { - case *DFloat: - return *t - } - panic(errors.AssertionFailedf("expected *DFloat, found %T", e)) -} - // NewDFloat is a helper routine to create a *DFloat initialized from its // argument. func NewDFloat(d DFloat) *DFloat { @@ -805,91 +430,6 @@ func (*DFloat) ResolvedType() *types.T { return types.Float } -// Compare implements the Datum interface. -func (d *DFloat) Compare(ctx *EvalContext, other Datum) int { - if other == DNull { - // NULL is less than any non-NULL value. - return 1 - } - var v DFloat - switch t := UnwrapDatum(ctx, other).(type) { - case *DFloat: - v = *t - case *DInt: - v = DFloat(MustBeDInt(t)) - case *DDecimal: - return -t.Compare(ctx, d) - default: - panic(makeUnsupportedComparisonMessage(d, other)) - } - if *d < v { - return -1 - } - if *d > v { - return 1 - } - // NaN sorts before non-NaN (#10109). - if *d == v { - return 0 - } - if math.IsNaN(float64(*d)) { - if math.IsNaN(float64(v)) { - return 0 - } - return -1 - } - return 1 -} - -// Prev implements the Datum interface. -func (d *DFloat) Prev(_ *EvalContext) (Datum, bool) { - f := float64(*d) - if math.IsNaN(f) { - return nil, false - } - if f == math.Inf(-1) { - return dNaNFloat, true - } - return NewDFloat(DFloat(math.Nextafter(f, math.Inf(-1)))), true -} - -// Next implements the Datum interface. -func (d *DFloat) Next(_ *EvalContext) (Datum, bool) { - f := float64(*d) - if math.IsNaN(f) { - return dNegInfFloat, true - } - if f == math.Inf(+1) { - return nil, false - } - return NewDFloat(DFloat(math.Nextafter(f, math.Inf(+1)))), true -} - -var dZeroFloat = NewDFloat(0.0) -var dPosInfFloat = NewDFloat(DFloat(math.Inf(+1))) -var dNegInfFloat = NewDFloat(DFloat(math.Inf(-1))) -var dNaNFloat = NewDFloat(DFloat(math.NaN())) - -// IsMax implements the Datum interface. -func (d *DFloat) IsMax(_ *EvalContext) bool { - return *d == *dPosInfFloat -} - -// IsMin implements the Datum interface. -func (d *DFloat) IsMin(_ *EvalContext) bool { - return math.IsNaN(float64(*d)) -} - -// Max implements the Datum interface. -func (d *DFloat) Max(_ *EvalContext) (Datum, bool) { - return dPosInfFloat, true -} - -// Min implements the Datum interface. -func (d *DFloat) Min(_ *EvalContext) (Datum, bool) { - return dNaNFloat, true -} - // AmbiguousFormat implements the Datum interface. func (*DFloat) AmbiguousFormat() bool { return true } @@ -938,16 +478,6 @@ type DDecimal struct { apd.Decimal } -// MustBeDDecimal attempts to retrieve a DDecimal from an Expr, panicking if the -// assertion fails. -func MustBeDDecimal(e Expr) DDecimal { - switch t := e.(type) { - case *DDecimal: - return *t - } - panic(errors.AssertionFailedf("expected *DDecimal, found %T", e)) -} - // ParseDDecimal parses and returns the *DDecimal Datum value represented by the // provided string, or an error if parsing is unsuccessful. func ParseDDecimal(s string) (*DDecimal, error) { @@ -985,76 +515,6 @@ func (*DDecimal) ResolvedType() *types.T { return types.Decimal } -// Compare implements the Datum interface. -func (d *DDecimal) Compare(ctx *EvalContext, other Datum) int { - if other == DNull { - // NULL is less than any non-NULL value. - return 1 - } - v := ctx.getTmpDec() - switch t := UnwrapDatum(ctx, other).(type) { - case *DDecimal: - v = &t.Decimal - case *DInt: - v.SetInt64(int64(*t)) - case *DFloat: - if _, err := v.SetFloat64(float64(*t)); err != nil { - panic(errors.NewAssertionErrorWithWrappedErrf(err, "decimal compare, unexpected error")) - } - default: - panic(makeUnsupportedComparisonMessage(d, other)) - } - return CompareDecimals(&d.Decimal, v) -} - -// CompareDecimals compares 2 apd.Decimals according to the SQL comparison -// rules, making sure that NaNs sort first. -func CompareDecimals(d *apd.Decimal, v *apd.Decimal) int { - // NaNs sort first in SQL. - if dn, vn := d.Form == apd.NaN, v.Form == apd.NaN; dn && !vn { - return -1 - } else if !dn && vn { - return 1 - } else if dn && vn { - return 0 - } - return d.Cmp(v) -} - -// Prev implements the Datum interface. -func (d *DDecimal) Prev(_ *EvalContext) (Datum, bool) { - return nil, false -} - -// Next implements the Datum interface. -func (d *DDecimal) Next(_ *EvalContext) (Datum, bool) { - return nil, false -} - -var dZeroDecimal = &DDecimal{Decimal: apd.Decimal{}} -var dPosInfDecimal = &DDecimal{Decimal: apd.Decimal{Form: apd.Infinite, Negative: false}} -var dNaNDecimal = &DDecimal{Decimal: apd.Decimal{Form: apd.NaN}} - -// IsMax implements the Datum interface. -func (d *DDecimal) IsMax(_ *EvalContext) bool { - return d.Form == apd.Infinite && !d.Negative -} - -// IsMin implements the Datum interface. -func (d *DDecimal) IsMin(_ *EvalContext) bool { - return d.Form == apd.NaN -} - -// Max implements the Datum interface. -func (d *DDecimal) Max(_ *EvalContext) (Datum, bool) { - return dPosInfDecimal, true -} - -// Min implements the Datum interface. -func (d *DDecimal) Min(_ *EvalContext) (Datum, bool) { - return dNaNDecimal, true -} - // AmbiguousFormat implements the Datum interface. func (*DDecimal) AmbiguousFormat() bool { return true } @@ -1119,86 +579,11 @@ func NewDString(d string) *DString { return &r } -// AsDString attempts to retrieve a DString from an Expr, returning a DString and -// a flag signifying whether the assertion was successful. The function should -// be used instead of direct type assertions wherever a *DString wrapped by a -// *DOidWrapper is possible. -func AsDString(e Expr) (DString, bool) { - switch t := e.(type) { - case *DString: - return *t, true - case *DOidWrapper: - return AsDString(t.Wrapped) - } - return "", false -} - -// MustBeDString attempts to retrieve a DString from an Expr, panicking if the -// assertion fails. -func MustBeDString(e Expr) DString { - i, ok := AsDString(e) - if !ok { - panic(errors.AssertionFailedf("expected *DString, found %T", e)) - } - return i -} - // ResolvedType implements the TypedExpr interface. func (*DString) ResolvedType() *types.T { return types.String } -// Compare implements the Datum interface. -func (d *DString) Compare(ctx *EvalContext, other Datum) int { - if other == DNull { - // NULL is less than any non-NULL value. - return 1 - } - v, ok := UnwrapDatum(ctx, other).(*DString) - if !ok { - panic(makeUnsupportedComparisonMessage(d, other)) - } - if *d < *v { - return -1 - } - if *d > *v { - return 1 - } - return 0 -} - -// Prev implements the Datum interface. -func (d *DString) Prev(_ *EvalContext) (Datum, bool) { - return nil, false -} - -// Next implements the Datum interface. -func (d *DString) Next(_ *EvalContext) (Datum, bool) { - return NewDString(string(roachpb.Key(*d).Next())), true -} - -// IsMax implements the Datum interface. -func (*DString) IsMax(_ *EvalContext) bool { - return false -} - -// IsMin implements the Datum interface. -func (d *DString) IsMin(_ *EvalContext) bool { - return len(*d) == 0 -} - -var dEmptyString = NewDString("") - -// Min implements the Datum interface. -func (d *DString) Min(_ *EvalContext) (Datum, bool) { - return dEmptyString, true -} - -// Max implements the Datum interface. -func (d *DString) Max(_ *EvalContext) (Datum, bool) { - return nil, false -} - // AmbiguousFormat implements the Datum interface. func (*DString) AmbiguousFormat() bool { return true } @@ -1260,25 +645,6 @@ func (env *CollationEnvironment) getCacheEntry( return entry, nil } -// NewDCollatedString is a helper routine to create a *DCollatedString. Panics -// if locale is invalid. Not safe for concurrent use. -func NewDCollatedString( - contents string, locale string, env *CollationEnvironment, -) (*DCollatedString, error) { - entry, err := env.getCacheEntry(locale) - if err != nil { - return nil, err - } - if env.buffer == nil { - env.buffer = &collate.Buffer{} - } - key := entry.collator.KeyFromString(env.buffer, contents) - d := DCollatedString{contents, entry.locale, make([]byte, len(key))} - copy(d.Key, key) - env.buffer.Reset() - return &d, nil -} - // AmbiguousFormat implements the Datum interface. func (*DCollatedString) AmbiguousFormat() bool { return false } @@ -1294,52 +660,9 @@ func (d *DCollatedString) ResolvedType() *types.T { return types.MakeCollatedString(types.String, d.Locale) } -// Compare implements the Datum interface. -func (d *DCollatedString) Compare(ctx *EvalContext, other Datum) int { - if other == DNull { - // NULL is less than any non-NULL value. - return 1 - } - v, ok := UnwrapDatum(ctx, other).(*DCollatedString) - if !ok || d.Locale != v.Locale { - panic(makeUnsupportedComparisonMessage(d, other)) - } - return bytes.Compare(d.Key, v.Key) -} - -// Prev implements the Datum interface. -func (d *DCollatedString) Prev(_ *EvalContext) (Datum, bool) { - return nil, false -} - -// Next implements the Datum interface. -func (d *DCollatedString) Next(_ *EvalContext) (Datum, bool) { - return nil, false -} - -// IsMax implements the Datum interface. -func (*DCollatedString) IsMax(_ *EvalContext) bool { - return false -} - -// IsMin implements the Datum interface. -func (d *DCollatedString) IsMin(_ *EvalContext) bool { - return d.Contents == "" -} - -// Min implements the Datum interface. -func (d *DCollatedString) Min(_ *EvalContext) (Datum, bool) { - return &DCollatedString{"", d.Locale, nil}, true -} - -// Max implements the Datum interface. -func (d *DCollatedString) Max(_ *EvalContext) (Datum, bool) { - return nil, false -} - -// Size implements the Datum interface. -func (d *DCollatedString) Size() uintptr { - return unsafe.Sizeof(*d) + uintptr(len(d.Contents)) + uintptr(len(d.Locale)) + uintptr(len(d.Key)) +// Size implements the Datum interface. +func (d *DCollatedString) Size() uintptr { + return unsafe.Sizeof(*d) + uintptr(len(d.Contents)) + uintptr(len(d.Locale)) + uintptr(len(d.Key)) } // IsComposite implements the CompositeDatum interface. @@ -1357,81 +680,11 @@ func NewDBytes(d DBytes) *DBytes { return &d } -// MustBeDBytes attempts to convert an Expr into a DBytes, panicking if unsuccessful. -func MustBeDBytes(e Expr) DBytes { - i, ok := AsDBytes(e) - if !ok { - panic(errors.AssertionFailedf("expected *DBytes, found %T", e)) - } - return i -} - -// AsDBytes attempts to convert an Expr into a DBytes, returning a flag indicating -// whether it was successful. -func AsDBytes(e Expr) (DBytes, bool) { - switch t := e.(type) { - case *DBytes: - return *t, true - } - return "", false -} - // ResolvedType implements the TypedExpr interface. func (*DBytes) ResolvedType() *types.T { return types.Bytes } -// Compare implements the Datum interface. -func (d *DBytes) Compare(ctx *EvalContext, other Datum) int { - if other == DNull { - // NULL is less than any non-NULL value. - return 1 - } - v, ok := UnwrapDatum(ctx, other).(*DBytes) - if !ok { - panic(makeUnsupportedComparisonMessage(d, other)) - } - if *d < *v { - return -1 - } - if *d > *v { - return 1 - } - return 0 -} - -// Prev implements the Datum interface. -func (d *DBytes) Prev(_ *EvalContext) (Datum, bool) { - return nil, false -} - -// Next implements the Datum interface. -func (d *DBytes) Next(_ *EvalContext) (Datum, bool) { - return NewDBytes(DBytes(roachpb.Key(*d).Next())), true -} - -// IsMax implements the Datum interface. -func (*DBytes) IsMax(_ *EvalContext) bool { - return false -} - -// IsMin implements the Datum interface. -func (d *DBytes) IsMin(_ *EvalContext) bool { - return len(*d) == 0 -} - -var dEmptyBytes = NewDBytes(DBytes("")) - -// Min implements the Datum interface. -func (d *DBytes) Min(_ *EvalContext) (Datum, bool) { - return dEmptyBytes, true -} - -// Max implements the Datum interface. -func (d *DBytes) Max(_ *EvalContext) (Datum, bool) { - return nil, false -} - // AmbiguousFormat implements the Datum interface. func (*DBytes) AmbiguousFormat() bool { return true } @@ -1486,64 +739,6 @@ func (*DUuid) ResolvedType() *types.T { return types.Uuid } -// Compare implements the Datum interface. -func (d *DUuid) Compare(ctx *EvalContext, other Datum) int { - if other == DNull { - // NULL is less than any non-NULL value. - return 1 - } - v, ok := UnwrapDatum(ctx, other).(*DUuid) - if !ok { - panic(makeUnsupportedComparisonMessage(d, other)) - } - return bytes.Compare(d.GetBytes(), v.GetBytes()) -} - -func (d *DUuid) equal(other *DUuid) bool { - return bytes.Equal(d.GetBytes(), other.GetBytes()) -} - -// Prev implements the Datum interface. -func (d *DUuid) Prev(_ *EvalContext) (Datum, bool) { - i := d.ToUint128() - u := uuid.FromUint128(i.Sub(1)) - return NewDUuid(DUuid{u}), true -} - -// Next implements the Datum interface. -func (d *DUuid) Next(_ *EvalContext) (Datum, bool) { - i := d.ToUint128() - u := uuid.FromUint128(i.Add(1)) - return NewDUuid(DUuid{u}), true -} - -// IsMax implements the Datum interface. -func (d *DUuid) IsMax(_ *EvalContext) bool { - return d.equal(DMaxUUID) -} - -// IsMin implements the Datum interface. -func (d *DUuid) IsMin(_ *EvalContext) bool { - return d.equal(DMinUUID) -} - -// DMinUUID is the min UUID. -var DMinUUID = NewDUuid(DUuid{uuid.UUID{}}) - -// DMaxUUID is the max UUID. -var DMaxUUID = NewDUuid(DUuid{uuid.UUID{0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, - 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF}}) - -// Min implements the Datum interface. -func (*DUuid) Min(_ *EvalContext) (Datum, bool) { - return DMinUUID, true -} - -// Max implements the Datum interface. -func (*DUuid) Max(_ *EvalContext) (Datum, bool) { - return DMaxUUID, true -} - // AmbiguousFormat implements the Datum interface. func (*DUuid) AmbiguousFormat() bool { return true } @@ -1570,140 +765,15 @@ type DIPAddr struct { ipaddr.IPAddr } -// NewDIPAddr is a helper routine to create a *DIPAddr initialized from its -// argument. -func NewDIPAddr(d DIPAddr) *DIPAddr { - return &d -} - -// AsDIPAddr attempts to retrieve a *DIPAddr from an Expr, returning a *DIPAddr and -// a flag signifying whether the assertion was successful. The function should -// be used instead of direct type assertions wherever a *DIPAddr wrapped by a -// *DOidWrapper is possible. -func AsDIPAddr(e Expr) (DIPAddr, bool) { - switch t := e.(type) { - case *DIPAddr: - return *t, true - case *DOidWrapper: - return AsDIPAddr(t.Wrapped) - } - return DIPAddr{}, false -} - -// MustBeDIPAddr attempts to retrieve a DIPAddr from an Expr, panicking if the -// assertion fails. -func MustBeDIPAddr(e Expr) DIPAddr { - i, ok := AsDIPAddr(e) - if !ok { - panic(errors.AssertionFailedf("expected *DIPAddr, found %T", e)) - } - return i -} - // ResolvedType implements the TypedExpr interface. func (*DIPAddr) ResolvedType() *types.T { return types.INet } -// Compare implements the Datum interface. -func (d *DIPAddr) Compare(ctx *EvalContext, other Datum) int { - if other == DNull { - // NULL is less than any non-NULL value. - return 1 - } - v, ok := UnwrapDatum(ctx, other).(*DIPAddr) - if !ok { - panic(makeUnsupportedComparisonMessage(d, other)) - } - - return d.IPAddr.Compare(&v.IPAddr) -} - func (d DIPAddr) equal(other *DIPAddr) bool { return d.IPAddr.Equal(&other.IPAddr) } -// Prev implements the Datum interface. -func (d *DIPAddr) Prev(_ *EvalContext) (Datum, bool) { - // We will do one of the following to get the Prev IPAddr: - // - Decrement IP address if we won't underflow the IP. - // - Decrement mask and set the IP to max in family if we will underflow. - // - Jump down from IPv6 to IPv4 if we will underflow both IP and mask. - if d.Family == ipaddr.IPv6family && d.Addr.Equal(dIPv6min) { - if d.Mask == 0 { - // Jump down IP family. - return dMaxIPv4Addr, true - } - // Decrease mask size, wrap IPv6 IP address. - return NewDIPAddr(DIPAddr{ipaddr.IPAddr{Family: ipaddr.IPv6family, Addr: dIPv6max, Mask: d.Mask - 1}}), true - } else if d.Family == ipaddr.IPv4family && d.Addr.Equal(dIPv4min) { - // Decrease mask size, wrap IPv4 IP address. - return NewDIPAddr(DIPAddr{ipaddr.IPAddr{Family: ipaddr.IPv4family, Addr: dIPv4max, Mask: d.Mask - 1}}), true - } - // Decrement IP address. - return NewDIPAddr(DIPAddr{ipaddr.IPAddr{Family: d.Family, Addr: d.Addr.Sub(1), Mask: d.Mask}}), true -} - -// Next implements the Datum interface. -func (d *DIPAddr) Next(_ *EvalContext) (Datum, bool) { - // We will do one of a few things to get the Next IP address: - // - Increment IP address if we won't overflow the IP. - // - Increment mask and set the IP to min in family if we will overflow. - // - Jump up from IPv4 to IPv6 if we will overflow both IP and mask. - if d.Family == ipaddr.IPv4family && d.Addr.Equal(dIPv4max) { - if d.Mask == 32 { - // Jump up IP family. - return dMinIPv6Addr, true - } - // Increase mask size, wrap IPv4 IP address. - return NewDIPAddr(DIPAddr{ipaddr.IPAddr{Family: ipaddr.IPv4family, Addr: dIPv4min, Mask: d.Mask + 1}}), true - } else if d.Family == ipaddr.IPv6family && d.Addr.Equal(dIPv6max) { - // Increase mask size, wrap IPv6 IP address. - return NewDIPAddr(DIPAddr{ipaddr.IPAddr{Family: ipaddr.IPv6family, Addr: dIPv6min, Mask: d.Mask + 1}}), true - } - // Increment IP address. - return NewDIPAddr(DIPAddr{ipaddr.IPAddr{Family: d.Family, Addr: d.Addr.Add(1), Mask: d.Mask}}), true -} - -// IsMax implements the Datum interface. -func (d *DIPAddr) IsMax(_ *EvalContext) bool { - return d.equal(DMaxIPAddr) -} - -// IsMin implements the Datum interface. -func (d *DIPAddr) IsMin(_ *EvalContext) bool { - return d.equal(DMinIPAddr) -} - -// dIPv4 and dIPv6 min and maxes use ParseIP because the actual byte constant is -// no equal to solely zeros or ones. For IPv4 there is a 0xffff prefix. Without -// this prefix this makes IP arithmetic invalid. -var dIPv4min = ipaddr.Addr(utils.FromBytes([]byte(net.ParseIP("0.0.0.0")))) -var dIPv4max = ipaddr.Addr(utils.FromBytes([]byte(net.ParseIP("255.255.255.255")))) -var dIPv6min = ipaddr.Addr(utils.FromBytes([]byte(net.ParseIP("::")))) -var dIPv6max = ipaddr.Addr(utils.FromBytes([]byte(net.ParseIP("ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff")))) - -// dMaxIPv4Addr and dMinIPv6Addr are used as global constants to prevent extra -// heap extra allocation -var dMaxIPv4Addr = NewDIPAddr(DIPAddr{ipaddr.IPAddr{Family: ipaddr.IPv4family, Addr: dIPv4max, Mask: 32}}) -var dMinIPv6Addr = NewDIPAddr(DIPAddr{ipaddr.IPAddr{Family: ipaddr.IPv6family, Addr: dIPv6min, Mask: 0}}) - -// DMinIPAddr is the min DIPAddr. -var DMinIPAddr = NewDIPAddr(DIPAddr{ipaddr.IPAddr{Family: ipaddr.IPv4family, Addr: dIPv4min, Mask: 0}}) - -// DMaxIPAddr is the max DIPaddr. -var DMaxIPAddr = NewDIPAddr(DIPAddr{ipaddr.IPAddr{Family: ipaddr.IPv6family, Addr: dIPv6max, Mask: 128}}) - -// Min implements the Datum interface. -func (*DIPAddr) Min(_ *EvalContext) (Datum, bool) { - return DMinIPAddr, true -} - -// Max implements the Datum interface. -func (*DIPAddr) Max(_ *EvalContext) (Datum, bool) { - return DMaxIPAddr, true -} - // AmbiguousFormat implements the Datum interface. func (*DIPAddr) AmbiguousFormat() bool { return true @@ -1739,17 +809,6 @@ func NewDDate(d pgdate.Date) *DDate { return &DDate{Date: d} } -// MakeDDate makes a DDate from a pgdate.Date. -func MakeDDate(d pgdate.Date) DDate { - return DDate{Date: d} -} - -// NewDDateFromTime constructs a *DDate from a time.Time. -func NewDDateFromTime(t time.Time) (*DDate, error) { - d, err := pgdate.MakeDateFromTime(t) - return NewDDate(d), err -} - // ParseTimeContext provides the information necessary for // parsing dates, times, and timestamps. A nil value is generally // acceptable and will result in reasonable defaults being applied. @@ -1761,17 +820,8 @@ type ParseTimeContext interface { GetRelativeParseTime() time.Time } -var _ ParseTimeContext = &EvalContext{} var _ ParseTimeContext = &simpleParseTimeContext{} -// NewParseTimeContext constructs a ParseTimeContext that returns -// the given values. -func NewParseTimeContext(relativeParseTime time.Time) ParseTimeContext { - return &simpleParseTimeContext{ - RelativeParseTime: relativeParseTime, - } -} - type simpleParseTimeContext struct { RelativeParseTime time.Time } @@ -1806,87 +856,6 @@ func (*DDate) ResolvedType() *types.T { return types.Date } -// Compare implements the Datum interface. -func (d *DDate) Compare(ctx *EvalContext, other Datum) int { - if other == DNull { - // NULL is less than any non-NULL value. - return 1 - } - var v DDate - switch t := UnwrapDatum(ctx, other).(type) { - case *DDate: - v = *t - case *DTimestamp, *DTimestampTZ: - return compareTimestamps(ctx, d, other) - default: - panic(makeUnsupportedComparisonMessage(d, other)) - } - return d.Date.Compare(v.Date) -} - -var ( - epochDate, _ = pgdate.MakeDateFromPGEpoch(0) - dEpochDate = NewDDate(epochDate) - dMaxDate = NewDDate(pgdate.PosInfDate) - dMinDate = NewDDate(pgdate.NegInfDate) - dLowDate = NewDDate(pgdate.LowDate) - dHighDate = NewDDate(pgdate.HighDate) -) - -// Prev implements the Datum interface. -func (d *DDate) Prev(_ *EvalContext) (Datum, bool) { - switch d.Date { - case pgdate.PosInfDate: - return dHighDate, true - case pgdate.LowDate: - return dMinDate, true - case pgdate.NegInfDate: - return nil, false - } - n, err := d.AddDays(-1) - if err != nil { - return nil, false - } - return NewDDate(n), true -} - -// Next implements the Datum interface. -func (d *DDate) Next(_ *EvalContext) (Datum, bool) { - switch d.Date { - case pgdate.NegInfDate: - return dLowDate, true - case pgdate.HighDate: - return dMaxDate, true - case pgdate.PosInfDate: - return nil, false - } - n, err := d.AddDays(1) - if err != nil { - return nil, false - } - return NewDDate(n), true -} - -// IsMax implements the Datum interface. -func (d *DDate) IsMax(_ *EvalContext) bool { - return d.Date == pgdate.PosInfDate -} - -// IsMin implements the Datum interface. -func (d *DDate) IsMin(_ *EvalContext) bool { - return d.Date == pgdate.NegInfDate -} - -// Max implements the Datum interface. -func (d *DDate) Max(_ *EvalContext) (Datum, bool) { - return dMaxDate, true -} - -// Min implements the Datum interface. -func (d *DDate) Min(_ *EvalContext) (Datum, bool) { - return dMinDate, true -} - // AmbiguousFormat implements the Datum interface. func (*DDate) AmbiguousFormat() bool { return true } @@ -1950,61 +919,11 @@ func (*DTime) ResolvedType() *types.T { return types.Time } -// Compare implements the Datum interface. -func (d *DTime) Compare(ctx *EvalContext, other Datum) int { - if other == DNull { - // NULL is less than any non-NULL value. - return 1 - } - return compareTimestamps(ctx, d, other) -} - -// Prev implements the Datum interface. -func (d *DTime) Prev(ctx *EvalContext) (Datum, bool) { - if d.IsMin(ctx) { - return nil, false - } - prev := *d - 1 - return &prev, true -} - // Round returns a new DTime to the specified precision. func (d *DTime) Round(precision time.Duration) *DTime { return MakeDTime(timeofday.TimeOfDay(*d).Round(precision)) } -// Next implements the Datum interface. -func (d *DTime) Next(ctx *EvalContext) (Datum, bool) { - if d.IsMax(ctx) { - return nil, false - } - next := *d + 1 - return &next, true -} - -var dTimeMin = MakeDTime(timeofday.Min) -var dTimeMax = MakeDTime(timeofday.Max) - -// IsMax implements the Datum interface. -func (d *DTime) IsMax(_ *EvalContext) bool { - return *d == *dTimeMax -} - -// IsMin implements the Datum interface. -func (d *DTime) IsMin(_ *EvalContext) bool { - return *d == *dTimeMin -} - -// Max implements the Datum interface. -func (d *DTime) Max(_ *EvalContext) (Datum, bool) { - return dTimeMax, true -} - -// Min implements the Datum interface. -func (d *DTime) Min(_ *EvalContext) (Datum, bool) { - return dTimeMin, true -} - // AmbiguousFormat implements the Datum interface. func (*DTime) AmbiguousFormat() bool { return true } @@ -2031,34 +950,16 @@ type DTimeTZ struct { timetz.TimeTZ } -var ( - dZeroTimeTZ = NewDTimeTZFromOffset(timeofday.Min, 0) - // DMinTimeTZ is the min TimeTZ. - DMinTimeTZ = NewDTimeTZFromOffset(timeofday.Min, timetz.MinTimeTZOffsetSecs) - // DMaxTimeTZ is the max TimeTZ. - DMaxTimeTZ = NewDTimeTZFromOffset(timeofday.Max, timetz.MaxTimeTZOffsetSecs) -) - // NewDTimeTZ creates a DTimeTZ from a timetz.TimeTZ. func NewDTimeTZ(t timetz.TimeTZ) *DTimeTZ { return &DTimeTZ{t} } -// NewDTimeTZFromTime creates a DTimeTZ from time.Time. -func NewDTimeTZFromTime(t time.Time) *DTimeTZ { - return &DTimeTZ{timetz.MakeTimeTZFromTime(t)} -} - // NewDTimeTZFromOffset creates a DTimeTZ from a TimeOfDay and offset. func NewDTimeTZFromOffset(t timeofday.TimeOfDay, offsetSecs int32) *DTimeTZ { return &DTimeTZ{timetz.MakeTimeTZ(t, offsetSecs)} } -// NewDTimeTZFromLocation creates a DTimeTZ from a TimeOfDay and time.Location. -func NewDTimeTZFromLocation(t timeofday.TimeOfDay, loc *time.Location) *DTimeTZ { - return &DTimeTZ{timetz.MakeTimeTZFromLocation(t, loc)} -} - // ParseDTimeTZ parses and returns the *DTime Datum value represented by the // provided string, or an error if parsing is unsuccessful. // @@ -2080,56 +981,11 @@ func (*DTimeTZ) ResolvedType() *types.T { return types.TimeTZ } -// Compare implements the Datum interface. -func (d *DTimeTZ) Compare(ctx *EvalContext, other Datum) int { - if other == DNull { - // NULL is less than any non-NULL value. - return 1 - } - return compareTimestamps(ctx, d, other) -} - -// Prev implements the Datum interface. -func (d *DTimeTZ) Prev(ctx *EvalContext) (Datum, bool) { - if d.IsMin(ctx) { - return nil, false - } - return NewDTimeTZFromOffset(d.TimeOfDay-1, d.OffsetSecs), true -} - -// Next implements the Datum interface. -func (d *DTimeTZ) Next(ctx *EvalContext) (Datum, bool) { - if d.IsMax(ctx) { - return nil, false - } - return NewDTimeTZFromOffset(d.TimeOfDay+1, d.OffsetSecs), true -} - -// IsMax implements the Datum interface. -func (d *DTimeTZ) IsMax(_ *EvalContext) bool { - return d.TimeOfDay == DMaxTimeTZ.TimeOfDay && d.OffsetSecs == timetz.MaxTimeTZOffsetSecs -} - -// IsMin implements the Datum interface. -func (d *DTimeTZ) IsMin(_ *EvalContext) bool { - return d.TimeOfDay == DMinTimeTZ.TimeOfDay && d.OffsetSecs == timetz.MinTimeTZOffsetSecs -} - -// Max implements the Datum interface. -func (d *DTimeTZ) Max(_ *EvalContext) (Datum, bool) { - return DMaxTimeTZ, true -} - // Round returns a new DTimeTZ to the specified precision. func (d *DTimeTZ) Round(precision time.Duration) *DTimeTZ { return NewDTimeTZ(d.TimeTZ.Round(precision)) } -// Min implements the Datum interface. -func (d *DTimeTZ) Min(_ *EvalContext) (Datum, bool) { - return DMinTimeTZ, true -} - // AmbiguousFormat implements the Datum interface. func (*DTimeTZ) AmbiguousFormat() bool { return true } @@ -2176,8 +1032,6 @@ func MustMakeDTimestamp(t time.Time, precision time.Duration) *DTimestamp { return ret } -var dZeroTimestamp = &DTimestamp{} - // time.Time formats. const ( // TimestampTZOutputFormat is used to output all TimestampTZs. @@ -2203,30 +1057,6 @@ func ParseDTimestamp( return d, dependsOnContext, err } -// AsDTimestamp attempts to retrieve a DTimestamp from an Expr, returning a DTimestamp and -// a flag signifying whether the assertion was successful. The function should -// be used instead of direct type assertions wherever a *DTimestamp wrapped by a -// *DOidWrapper is possible. -func AsDTimestamp(e Expr) (DTimestamp, bool) { - switch t := e.(type) { - case *DTimestamp: - return *t, true - case *DOidWrapper: - return AsDTimestamp(t.Wrapped) - } - return DTimestamp{}, false -} - -// MustBeDTimestamp attempts to retrieve a DTimestamp from an Expr, panicking if the -// assertion fails. -func MustBeDTimestamp(e Expr) DTimestamp { - t, ok := AsDTimestamp(e) - if !ok { - panic(errors.AssertionFailedf("expected *DTimestamp, found %T", e)) - } - return t -} - // Round returns a new DTimestamp to the specified precision. func (d *DTimestamp) Round(precision time.Duration) (*DTimestamp, error) { return MakeDTimestamp(d.Time, precision) @@ -2237,129 +1067,6 @@ func (*DTimestamp) ResolvedType() *types.T { return types.Timestamp } -// timeFromDatumForComparison gets the time from a datum object to use -// strictly for comparison usage. -func timeFromDatumForComparison(ctx *EvalContext, d Datum) (time.Time, bool) { - d = UnwrapDatum(ctx, d) - switch t := d.(type) { - case *DDate: - ts, err := MakeDTimestampTZFromDate(ctx.GetLocation(), t) - if err != nil { - return time.Time{}, false - } - return ts.Time, true - case *DTimestampTZ: - return t.Time, true - case *DTimestamp: - // Normalize to the timezone of the context. - _, zoneOffset := t.Time.In(ctx.GetLocation()).Zone() - ts := t.Time.In(ctx.GetLocation()).Add(-time.Duration(zoneOffset) * time.Second) - return ts, true - case *DTime: - // Normalize to the timezone of the context. - toTime := timeofday.TimeOfDay(*t).ToTime() - _, zoneOffsetSecs := toTime.In(ctx.GetLocation()).Zone() - return toTime.In(ctx.GetLocation()).Add(-time.Duration(zoneOffsetSecs) * time.Second), true - case *DTimeTZ: - return t.ToTime(), true - default: - return time.Time{}, false - } -} - -func compareTimestamps(ctx *EvalContext, l Datum, r Datum) int { - lTime, lOk := timeFromDatumForComparison(ctx, l) - rTime, rOk := timeFromDatumForComparison(ctx, r) - if !lOk || !rOk { - panic(makeUnsupportedComparisonMessage(l, r)) - } - if lTime.Before(rTime) { - return -1 - } - if rTime.Before(lTime) { - return 1 - } - - // If either side is a TimeTZ, then we must compare timezones before - // when comparing. If comparing a non-TimeTZ value, and the times are - // equal, then we must compare relative to the current zone we are at. - // - // This is a special quirk of TimeTZ and does not apply to TimestampTZ, - // as TimestampTZ does not store a timezone offset and is based on - // the current zone. - _, leftIsTimeTZ := l.(*DTimeTZ) - _, rightIsTimeTZ := r.(*DTimeTZ) - - // If neither side is TimeTZ, this is always equal at this point. - if !leftIsTimeTZ && !rightIsTimeTZ { - return 0 - } - - _, zoneOffset := ctx.GetRelativeParseTime().Zone() - lOffset := int32(-zoneOffset) - rOffset := int32(-zoneOffset) - - if leftIsTimeTZ { - lOffset = l.(*DTimeTZ).OffsetSecs - } - if rightIsTimeTZ { - rOffset = r.(*DTimeTZ).OffsetSecs - } - - if lOffset > rOffset { - return 1 - } - if lOffset < rOffset { - return -1 - } - return 0 -} - -// Compare implements the Datum interface. -func (d *DTimestamp) Compare(ctx *EvalContext, other Datum) int { - if other == DNull { - // NULL is less than any non-NULL value. - return 1 - } - return compareTimestamps(ctx, d, other) -} - -// Prev implements the Datum interface. -func (d *DTimestamp) Prev(ctx *EvalContext) (Datum, bool) { - if d.IsMin(ctx) { - return nil, false - } - return &DTimestamp{Time: d.Add(-time.Microsecond)}, true -} - -// Next implements the Datum interface. -func (d *DTimestamp) Next(ctx *EvalContext) (Datum, bool) { - if d.IsMax(ctx) { - return nil, false - } - return &DTimestamp{Time: d.Add(time.Microsecond)}, true -} - -// IsMax implements the Datum interface. -func (d *DTimestamp) IsMax(_ *EvalContext) bool { - return d.Equal(MaxSupportedTime) -} - -// IsMin implements the Datum interface. -func (d *DTimestamp) IsMin(_ *EvalContext) bool { - return d.Equal(MinSupportedTime) -} - -// Min implements the Datum interface. -func (d *DTimestamp) Min(_ *EvalContext) (Datum, bool) { - return &DTimestamp{Time: MinSupportedTime}, true -} - -// Max implements the Datum interface. -func (d *DTimestamp) Max(_ *EvalContext) (Datum, bool) { - return &DTimestamp{Time: MaxSupportedTime}, true -} - // AmbiguousFormat implements the Datum interface. func (*DTimestamp) AmbiguousFormat() bool { return true } @@ -2405,19 +1112,6 @@ func MustMakeDTimestampTZ(t time.Time, precision time.Duration) *DTimestampTZ { return ret } -// MakeDTimestampTZFromDate creates a DTimestampTZ from a DDate. -// This will be equivalent to the midnight of the given zone. -func MakeDTimestampTZFromDate(loc *time.Location, d *DDate) (*DTimestampTZ, error) { - t, err := d.ToTime() - if err != nil { - return nil, err - } - // Normalize to the correct zone. - t = t.In(loc) - _, offset := t.Zone() - return MakeDTimestampTZ(t.Add(time.Duration(-offset)*time.Second), time.Microsecond) -} - // ParseDTimestampTZ parses and returns the *DTimestampTZ Datum value represented by // the provided string in the provided location, or an error if parsing is unsuccessful. // @@ -2436,32 +1130,6 @@ func ParseDTimestampTZ( return d, dependsOnContext, err } -var dZeroTimestampTZ = &DTimestampTZ{} - -// AsDTimestampTZ attempts to retrieve a DTimestampTZ from an Expr, returning a -// DTimestampTZ and a flag signifying whether the assertion was successful. The -// function should be used instead of direct type assertions wherever a -// *DTimestamp wrapped by a *DOidWrapper is possible. -func AsDTimestampTZ(e Expr) (DTimestampTZ, bool) { - switch t := e.(type) { - case *DTimestampTZ: - return *t, true - case *DOidWrapper: - return AsDTimestampTZ(t.Wrapped) - } - return DTimestampTZ{}, false -} - -// MustBeDTimestampTZ attempts to retrieve a DTimestampTZ from an Expr, -// panicking if the assertion fails. -func MustBeDTimestampTZ(e Expr) DTimestampTZ { - t, ok := AsDTimestampTZ(e) - if !ok { - panic(errors.AssertionFailedf("expected *DTimestampTZ, found %T", e)) - } - return t -} - // Round returns a new DTimestampTZ to the specified precision. func (d *DTimestampTZ) Round(precision time.Duration) (*DTimestampTZ, error) { return MakeDTimestampTZ(d.Time, precision) @@ -2472,51 +1140,6 @@ func (*DTimestampTZ) ResolvedType() *types.T { return types.TimestampTZ } -// Compare implements the Datum interface. -func (d *DTimestampTZ) Compare(ctx *EvalContext, other Datum) int { - if other == DNull { - // NULL is less than any non-NULL value. - return 1 - } - return compareTimestamps(ctx, d, other) -} - -// Prev implements the Datum interface. -func (d *DTimestampTZ) Prev(ctx *EvalContext) (Datum, bool) { - if d.IsMin(ctx) { - return nil, false - } - return &DTimestampTZ{Time: d.Add(-time.Microsecond)}, true -} - -// Next implements the Datum interface. -func (d *DTimestampTZ) Next(ctx *EvalContext) (Datum, bool) { - if d.IsMax(ctx) { - return nil, false - } - return &DTimestampTZ{Time: d.Add(time.Microsecond)}, true -} - -// IsMax implements the Datum interface. -func (d *DTimestampTZ) IsMax(_ *EvalContext) bool { - return d.Equal(MaxSupportedTime) -} - -// IsMin implements the Datum interface. -func (d *DTimestampTZ) IsMin(_ *EvalContext) bool { - return d.Equal(MinSupportedTime) -} - -// Min implements the Datum interface. -func (d *DTimestampTZ) Min(_ *EvalContext) (Datum, bool) { - return &DTimestampTZ{Time: MinSupportedTime}, true -} - -// Max implements the Datum interface. -func (d *DTimestampTZ) Max(_ *EvalContext) (Datum, bool) { - return &DTimestampTZ{Time: MaxSupportedTime}, true -} - // AmbiguousFormat implements the Datum interface. func (*DTimestampTZ) AmbiguousFormat() bool { return true } @@ -2538,43 +1161,11 @@ func (d *DTimestampTZ) Size() uintptr { return unsafe.Sizeof(*d) } -// stripTimeZone removes the time zone from this TimestampTZ. For example, a -// TimestampTZ '2012-01-01 12:00:00 +02:00' would become -// '2012-01-01 12:00:00'. -func (d *DTimestampTZ) stripTimeZone(ctx *EvalContext) (*DTimestamp, error) { - return d.EvalAtTimeZone(ctx, ctx.GetLocation()) -} - -// EvalAtTimeZone evaluates this TimestampTZ as if it were in the supplied -// location, returning a timestamp without a timezone. -func (d *DTimestampTZ) EvalAtTimeZone(ctx *EvalContext, loc *time.Location) (*DTimestamp, error) { - _, locOffset := d.Time.In(loc).Zone() - t := d.Time.UTC().Add(time.Duration(locOffset) * time.Second).UTC() - return MakeDTimestamp(t, time.Microsecond) -} - // DInterval is the interval Datum. type DInterval struct { duration.Duration } -// MustBeDInterval attempts to retrieve a DInterval from an Expr, panicking if the -// assertion fails. -func MustBeDInterval(e Expr) *DInterval { - switch t := e.(type) { - case *DInterval: - return t - } - panic(errors.AssertionFailedf("expected *DInterval, found %T", e)) -} - -// NewDInterval creates a new DInterval. -func NewDInterval(d duration.Duration, itm types.IntervalTypeMetadata) *DInterval { - ret := &DInterval{Duration: d} - truncateDInterval(ret, itm) - return ret -} - // ParseDInterval parses and returns the *DInterval Datum value represented by the provided // string, or an error if parsing is unsuccessful. func ParseDInterval(s string) (*DInterval, error) { @@ -2664,55 +1255,6 @@ func (*DInterval) ResolvedType() *types.T { return types.Interval } -// Compare implements the Datum interface. -func (d *DInterval) Compare(ctx *EvalContext, other Datum) int { - if other == DNull { - // NULL is less than any non-NULL value. - return 1 - } - v, ok := UnwrapDatum(ctx, other).(*DInterval) - if !ok { - panic(makeUnsupportedComparisonMessage(d, other)) - } - return d.Duration.Compare(v.Duration) -} - -// Prev implements the Datum interface. -func (d *DInterval) Prev(_ *EvalContext) (Datum, bool) { - return nil, false -} - -// Next implements the Datum interface. -func (d *DInterval) Next(_ *EvalContext) (Datum, bool) { - return nil, false -} - -// IsMax implements the Datum interface. -func (d *DInterval) IsMax(_ *EvalContext) bool { - return d.Duration == dMaxInterval.Duration -} - -// IsMin implements the Datum interface. -func (d *DInterval) IsMin(_ *EvalContext) bool { - return d.Duration == dMinInterval.Duration -} - -var ( - dZeroInterval = &DInterval{} - dMaxInterval = &DInterval{duration.MakeDuration(math.MaxInt64, math.MaxInt64, math.MaxInt64)} - dMinInterval = &DInterval{duration.MakeDuration(math.MinInt64, math.MinInt64, math.MinInt64)} -) - -// Max implements the Datum interface. -func (d *DInterval) Max(_ *EvalContext) (Datum, bool) { - return dMaxInterval, true -} - -// Min implements the Datum interface. -func (d *DInterval) Min(_ *EvalContext) (Datum, bool) { - return dMinInterval, true -} - // ValueAsString returns the interval as a string (e.g. "1h2m"). func (d *DInterval) ValueAsString() string { return d.Duration.String() @@ -2749,30 +1291,6 @@ func NewDGeography(g geo.Geography) *DGeography { return &DGeography{Geography: g} } -// AsDGeography attempts to retrieve a *DGeography from an Expr, returning a -// *DGeography and a flag signifying whether the assertion was successful. The -// function should be used instead of direct type assertions wherever a -// *DGeography wrapped by a *DOidWrapper is possible. -func AsDGeography(e Expr) (*DGeography, bool) { - switch t := e.(type) { - case *DGeography: - return t, true - case *DOidWrapper: - return AsDGeography(t.Wrapped) - } - return nil, false -} - -// MustBeDGeography attempts to retrieve a *DGeography from an Expr, panicking -// if the assertion fails. -func MustBeDGeography(e Expr) *DGeography { - i, ok := AsDGeography(e) - if !ok { - panic(errors.AssertionFailedf("expected *DGeography, found %T", e)) - } - return i -} - // ParseDGeography attempts to pass `str` as a Geography type. func ParseDGeography(str string) (*DGeography, error) { g, err := geo.ParseGeography(str) @@ -2787,45 +1305,6 @@ func (*DGeography) ResolvedType() *types.T { return types.Geography } -// Compare implements the Datum interface. -func (d *DGeography) Compare(ctx *EvalContext, other Datum) int { - if other == DNull { - // NULL is less than any non-NULL value. - return 1 - } - return d.Geography.Compare(other.(*DGeography).Geography) -} - -// Prev implements the Datum interface. -func (d *DGeography) Prev(ctx *EvalContext) (Datum, bool) { - return nil, false -} - -// Next implements the Datum interface. -func (d *DGeography) Next(ctx *EvalContext) (Datum, bool) { - return nil, false -} - -// IsMax implements the Datum interface. -func (d *DGeography) IsMax(_ *EvalContext) bool { - return false -} - -// IsMin implements the Datum interface. -func (d *DGeography) IsMin(_ *EvalContext) bool { - return false -} - -// Max implements the Datum interface. -func (d *DGeography) Max(_ *EvalContext) (Datum, bool) { - return nil, false -} - -// Min implements the Datum interface. -func (d *DGeography) Min(_ *EvalContext) (Datum, bool) { - return nil, false -} - // AmbiguousFormat implements the Datum interface. func (*DGeography) AmbiguousFormat() bool { return true } @@ -2857,30 +1336,6 @@ func NewDGeometry(g geo.Geometry) *DGeometry { return &DGeometry{Geometry: g} } -// AsDGeometry attempts to retrieve a *DGeometry from an Expr, returning a -// *DGeometry and a flag signifying whether the assertion was successful. The -// function should be used instead of direct type assertions wherever a -// *DGeometry wrapped by a *DOidWrapper is possible. -func AsDGeometry(e Expr) (*DGeometry, bool) { - switch t := e.(type) { - case *DGeometry: - return t, true - case *DOidWrapper: - return AsDGeometry(t.Wrapped) - } - return nil, false -} - -// MustBeDGeometry attempts to retrieve a *DGeometry from an Expr, panicking -// if the assertion fails. -func MustBeDGeometry(e Expr) *DGeometry { - i, ok := AsDGeometry(e) - if !ok { - panic(errors.AssertionFailedf("expected *DGeometry, found %T", e)) - } - return i -} - // ParseDGeometry attempts to pass `str` as a Geometry type. func ParseDGeometry(str string) (*DGeometry, error) { g, err := geo.ParseGeometry(str) @@ -2895,45 +1350,6 @@ func (*DGeometry) ResolvedType() *types.T { return types.Geometry } -// Compare implements the Datum interface. -func (d *DGeometry) Compare(ctx *EvalContext, other Datum) int { - if other == DNull { - // NULL is less than any non-NULL value. - return 1 - } - return d.Geometry.Compare(other.(*DGeometry).Geometry) -} - -// Prev implements the Datum interface. -func (d *DGeometry) Prev(ctx *EvalContext) (Datum, bool) { - return nil, false -} - -// Next implements the Datum interface. -func (d *DGeometry) Next(ctx *EvalContext) (Datum, bool) { - return nil, false -} - -// IsMax implements the Datum interface. -func (d *DGeometry) IsMax(_ *EvalContext) bool { - return false -} - -// IsMin implements the Datum interface. -func (d *DGeometry) IsMin(_ *EvalContext) bool { - return false -} - -// Max implements the Datum interface. -func (d *DGeometry) Max(_ *EvalContext) (Datum, bool) { - return nil, false -} - -// Min implements the Datum interface. -func (d *DGeometry) Min(_ *EvalContext) (Datum, bool) { - return nil, false -} - // AmbiguousFormat implements the Datum interface. func (*DGeometry) AmbiguousFormat() bool { return true } @@ -2974,75 +1390,11 @@ func ParseDBox2D(str string) (*DBox2D, error) { return &DBox2D{CartesianBoundingBox: b}, nil } -// AsDBox2D attempts to retrieve a *DBox2D from an Expr, returning a -// *DBox2D and a flag signifying whether the assertion was successful. The -// function should be used instead of direct type assertions wherever a -// *DBox2D wrapped by a *DOidWrapper is possible. -func AsDBox2D(e Expr) (*DBox2D, bool) { - switch t := e.(type) { - case *DBox2D: - return t, true - case *DOidWrapper: - return AsDBox2D(t.Wrapped) - } - return nil, false -} - -// MustBeDBox2D attempts to retrieve a *DBox2D from an Expr, panicking -// if the assertion fails. -func MustBeDBox2D(e Expr) *DBox2D { - i, ok := AsDBox2D(e) - if !ok { - panic(errors.AssertionFailedf("expected *DBox2D, found %T", e)) - } - return i -} - // ResolvedType implements the TypedExpr interface. func (*DBox2D) ResolvedType() *types.T { return types.Box2D } -// Compare implements the Datum interface. -func (d *DBox2D) Compare(ctx *EvalContext, other Datum) int { - if other == DNull { - // NULL is less than any non-NULL value. - return 1 - } - o := other.(*DBox2D) - return d.CartesianBoundingBox.Compare(&o.CartesianBoundingBox) -} - -// Prev implements the Datum interface. -func (d *DBox2D) Prev(ctx *EvalContext) (Datum, bool) { - return nil, false -} - -// Next implements the Datum interface. -func (d *DBox2D) Next(ctx *EvalContext) (Datum, bool) { - return nil, false -} - -// IsMax implements the Datum interface. -func (d *DBox2D) IsMax(_ *EvalContext) bool { - return false -} - -// IsMin implements the Datum interface. -func (d *DBox2D) IsMin(_ *EvalContext) bool { - return false -} - -// Max implements the Datum interface. -func (d *DBox2D) Max(_ *EvalContext) (Datum, bool) { - return nil, false -} - -// Min implements the Datum interface. -func (d *DBox2D) Min(_ *EvalContext) (Datum, bool) { - return nil, false -} - // AmbiguousFormat implements the Datum interface. func (*DBox2D) AmbiguousFormat() bool { return true } @@ -3081,177 +1433,11 @@ func ParseDJSON(s string) (Datum, error) { return NewDJSON(j), nil } -// MakeDJSON returns a JSON value given a Go-style representation of JSON. -// * JSON null is Go `nil`, -// * JSON true is Go `true`, -// * JSON false is Go `false`, -// * JSON numbers are json.Number | int | int64 | float64, -// * JSON string is a Go string, -// * JSON array is a Go []interface{}, -// * JSON object is a Go map[string]interface{}. -func MakeDJSON(d interface{}) (Datum, error) { - j, err := json.MakeJSON(d) - if err != nil { - return nil, err - } - return &DJSON{j}, nil -} - -var dNullJSON = NewDJSON(json.NullJSONValue) - -// AsDJSON attempts to retrieve a *DJSON from an Expr, returning a *DJSON and -// a flag signifying whether the assertion was successful. The function should -// be used instead of direct type assertions wherever a *DJSON wrapped by a -// *DOidWrapper is possible. -func AsDJSON(e Expr) (*DJSON, bool) { - switch t := e.(type) { - case *DJSON: - return t, true - case *DOidWrapper: - return AsDJSON(t.Wrapped) - } - return nil, false -} - -// MustBeDJSON attempts to retrieve a DJSON from an Expr, panicking if the -// assertion fails. -func MustBeDJSON(e Expr) DJSON { - i, ok := AsDJSON(e) - if !ok { - panic(errors.AssertionFailedf("expected *DJSON, found %T", e)) - } - return *i -} - -// AsJSON converts a datum into our standard json representation. -func AsJSON(d Datum, loc *time.Location) (json.JSON, error) { - switch t := d.(type) { - case *DBool: - return json.FromBool(bool(*t)), nil - case *DInt: - return json.FromInt(int(*t)), nil - case *DFloat: - return json.FromFloat64(float64(*t)) - case *DDecimal: - return json.FromDecimal(t.Decimal), nil - case *DString: - return json.FromString(string(*t)), nil - case *DCollatedString: - return json.FromString(t.Contents), nil - case *DEnum: - return json.FromString(t.LogicalRep), nil - case *DJSON: - return t.JSON, nil - case *DArray: - builder := json.NewArrayBuilder(t.Len()) - for _, e := range t.Array { - j, err := AsJSON(e, loc) - if err != nil { - return nil, err - } - builder.Add(j) - } - return builder.Build(), nil - case *DTuple: - builder := json.NewObjectBuilder(len(t.D)) - // We need to make sure that t.typ is initialized before getting the tuple - // labels (it is valid for t.typ be left uninitialized when instantiating a - // DTuple). - t.maybePopulateType() - labels := t.typ.TupleLabels() - for i, e := range t.D { - j, err := AsJSON(e, loc) - if err != nil { - return nil, err - } - var key string - if i >= len(labels) { - key = fmt.Sprintf("f%d", i+1) - } else { - key = labels[i] - } - builder.Add(key, j) - } - return builder.Build(), nil - case *DTimestampTZ: - // Our normal timestamp-formatting code uses a variation on RFC 3339, - // without the T separator. This causes some compatibility problems - // with certain JSON consumers, so we'll use an alternate formatting - // path here to maintain consistency with PostgreSQL. - return json.FromString(t.Time.In(loc).Format(time.RFC3339Nano)), nil - case *DTimestamp: - // This is RFC3339Nano, but without the TZ fields. - return json.FromString(t.UTC().Format("2006-01-02T15:04:05.999999999")), nil - case *DDate, *DUuid, *DOid, *DInterval, *DBytes, *DIPAddr, *DTime, *DTimeTZ, *DBitArray, *DBox2D: - return json.FromString(AsStringWithFlags(t, FmtBareStrings)), nil - case *DGeometry: - return json.FromSpatialObject(t.Geometry.SpatialObject(), geo.DefaultGeoJSONDecimalDigits) - case *DGeography: - return json.FromSpatialObject(t.Geography.SpatialObject(), geo.DefaultGeoJSONDecimalDigits) - default: - if d == DNull { - return json.NullJSONValue, nil - } - - return nil, errors.AssertionFailedf("unexpected type %T for AsJSON", d) - } -} - // ResolvedType implements the TypedExpr interface. func (*DJSON) ResolvedType() *types.T { return types.Jsonb } -// Compare implements the Datum interface. -func (d *DJSON) Compare(ctx *EvalContext, other Datum) int { - if other == DNull { - // NULL is less than any non-NULL value. - return 1 - } - v, ok := UnwrapDatum(ctx, other).(*DJSON) - if !ok { - panic(makeUnsupportedComparisonMessage(d, other)) - } - // No avenue for us to pass up this error here at the moment, but Compare - // only errors for invalid encoded data. - // TODO(justin): modify Compare to allow passing up errors. - c, err := d.JSON.Compare(v.JSON) - if err != nil { - panic(err) - } - return c -} - -// Prev implements the Datum interface. -func (d *DJSON) Prev(_ *EvalContext) (Datum, bool) { - return nil, false -} - -// Next implements the Datum interface. -func (d *DJSON) Next(_ *EvalContext) (Datum, bool) { - return nil, false -} - -// IsMax implements the Datum interface. -func (d *DJSON) IsMax(_ *EvalContext) bool { - return false -} - -// IsMin implements the Datum interface. -func (d *DJSON) IsMin(_ *EvalContext) bool { - return d.JSON == json.NullJSONValue -} - -// Max implements the Datum interface. -func (d *DJSON) Max(_ *EvalContext) (Datum, bool) { - return nil, false -} - -// Min implements the Datum interface. -func (d *DJSON) Min(_ *EvalContext) (Datum, bool) { - return &DJSON{json.NullJSONValue}, true -} - // AmbiguousFormat implements the Datum interface. func (*DJSON) AmbiguousFormat() bool { return true } @@ -3273,211 +1459,41 @@ func (d *DJSON) Size() uintptr { return unsafe.Sizeof(*d) + d.JSON.Size() } -// DTuple is the tuple Datum. -type DTuple struct { - D Datums - - // sorted indicates that the values in D are pre-sorted. - // This is used to accelerate IN comparisons. - sorted bool - - // typ is the tuple's type. - // - // The Types sub-field can be initially uninitialized, and is then - // populated upon first invocation of ResolvedTypes(). If - // initialized it must have the same arity as D. - // - // The Labels sub-field can be left nil. If populated, it must have - // the same arity as D. - typ *types.T -} - -// NewDTuple creates a *DTuple with the provided datums. When creating a new -// DTuple with Datums that are known to be sorted in ascending order, chain -// this call with DTuple.SetSorted. -func NewDTuple(typ *types.T, d ...Datum) *DTuple { - return &DTuple{D: d, typ: typ} -} - -// NewDTupleWithLen creates a *DTuple with the provided length. -func NewDTupleWithLen(typ *types.T, l int) *DTuple { - return &DTuple{D: make(Datums, l), typ: typ} -} - -// AsDTuple attempts to retrieve a *DTuple from an Expr, returning a *DTuple and -// a flag signifying whether the assertion was successful. The function should -// be used instead of direct type assertions wherever a *DTuple wrapped by a -// *DOidWrapper is possible. -func AsDTuple(e Expr) (*DTuple, bool) { - switch t := e.(type) { - case *DTuple: - return t, true - case *DOidWrapper: - return AsDTuple(t.Wrapped) - } - return nil, false -} - -// MustBeDTuple attempts to retrieve a *DTuple from an Expr, panicking if the -// assertion fails. -func MustBeDTuple(e Expr) *DTuple { - i, ok := AsDTuple(e) - if !ok { - panic(errors.AssertionFailedf("expected *DTuple, found %T", e)) - } - return i -} - -// maybePopulateType populates the tuple's type if it hasn't yet been -// populated. -func (d *DTuple) maybePopulateType() { - if d.typ == nil { - contents := make([]*types.T, len(d.D)) - for i, v := range d.D { - contents[i] = v.ResolvedType() - } - d.typ = types.MakeTuple(contents) - } -} - -// ResolvedType implements the TypedExpr interface. -func (d *DTuple) ResolvedType() *types.T { - d.maybePopulateType() - return d.typ -} - -// Compare implements the Datum interface. -func (d *DTuple) Compare(ctx *EvalContext, other Datum) int { - if other == DNull { - // NULL is less than any non-NULL value. - return 1 - } - v, ok := UnwrapDatum(ctx, other).(*DTuple) - if !ok { - panic(makeUnsupportedComparisonMessage(d, other)) - } - n := len(d.D) - if n > len(v.D) { - n = len(v.D) - } - for i := 0; i < n; i++ { - c := d.D[i].Compare(ctx, v.D[i]) - if c != 0 { - return c - } - } - if len(d.D) < len(v.D) { - return -1 - } - if len(d.D) > len(v.D) { - return 1 - } - return 0 -} - -// Prev implements the Datum interface. -func (d *DTuple) Prev(ctx *EvalContext) (Datum, bool) { - // Note: (a:decimal, b:int, c:int) has a prev value; that's (a, b, - // c-1). With an exception if c is MinInt64, in which case the prev - // value is (a, b-1, max(_ *EvalContext)). However, (a:int, b:decimal) does not - // have a prev value, because decimal doesn't have one. - // - // In general, a tuple has a prev value if and only if it ends with - // zero or more values that are a minimum and a maximum value of the - // same type exists, and the first element before that has a prev - // value. - res := NewDTupleWithLen(d.typ, len(d.D)) - copy(res.D, d.D) - for i := len(res.D) - 1; i >= 0; i-- { - if !res.D[i].IsMin(ctx) { - prevVal, ok := res.D[i].Prev(ctx) - if !ok { - return nil, false - } - res.D[i] = prevVal - break - } - maxVal, ok := res.D[i].Max(ctx) - if !ok { - return nil, false - } - res.D[i] = maxVal - } - return res, true -} - -// Next implements the Datum interface. -func (d *DTuple) Next(ctx *EvalContext) (Datum, bool) { - // Note: (a:decimal, b:int, c:int) has a next value; that's (a, b, - // c+1). With an exception if c is MaxInt64, in which case the next - // value is (a, b+1, min(_ *EvalContext)). However, (a:int, b:decimal) does not - // have a next value, because decimal doesn't have one. - // - // In general, a tuple has a next value if and only if it ends with - // zero or more values that are a maximum and a minimum value of the - // same type exists, and the first element before that has a next - // value. - res := NewDTupleWithLen(d.typ, len(d.D)) - copy(res.D, d.D) - for i := len(res.D) - 1; i >= 0; i-- { - if !res.D[i].IsMax(ctx) { - nextVal, ok := res.D[i].Next(ctx) - if !ok { - return nil, false - } - res.D[i] = nextVal - break - } - // TODO(#12022): temporary workaround; see the interface comment. - res.D[i] = DNull - } - return res, true -} - -// Max implements the Datum interface. -func (d *DTuple) Max(ctx *EvalContext) (Datum, bool) { - res := NewDTupleWithLen(d.typ, len(d.D)) - for i, v := range d.D { - m, ok := v.Max(ctx) - if !ok { - return nil, false - } - res.D[i] = m - } - return res, true -} +// DTuple is the tuple Datum. +type DTuple struct { + D Datums -// Min implements the Datum interface. -func (d *DTuple) Min(ctx *EvalContext) (Datum, bool) { - res := NewDTupleWithLen(d.typ, len(d.D)) - for i, v := range d.D { - m, ok := v.Min(ctx) - if !ok { - return nil, false - } - res.D[i] = m - } - return res, true + // sorted indicates that the values in D are pre-sorted. + // This is used to accelerate IN comparisons. + sorted bool + + // typ is the tuple's type. + // + // The Types sub-field can be initially uninitialized, and is then + // populated upon first invocation of ResolvedTypes(). If + // initialized it must have the same arity as D. + // + // The Labels sub-field can be left nil. If populated, it must have + // the same arity as D. + typ *types.T } -// IsMax implements the Datum interface. -func (d *DTuple) IsMax(ctx *EvalContext) bool { - for _, v := range d.D { - if !v.IsMax(ctx) { - return false +// maybePopulateType populates the tuple's type if it hasn't yet been +// populated. +func (d *DTuple) maybePopulateType() { + if d.typ == nil { + contents := make([]*types.T, len(d.D)) + for i, v := range d.D { + contents[i] = v.ResolvedType() } + d.typ = types.MakeTuple(contents) } - return true } -// IsMin implements the Datum interface. -func (d *DTuple) IsMin(ctx *EvalContext) bool { - for _, v := range d.D { - if !v.IsMin(ctx) { - return false - } - } - return true +// ResolvedType implements the TypedExpr interface. +func (d *DTuple) ResolvedType() *types.T { + d.maybePopulateType() + return d.typ } // AmbiguousFormat implements the Datum interface. @@ -3555,54 +1571,6 @@ func (d *DTuple) AssertSorted() { } } -// SearchSorted searches the tuple for the target Datum, returning an int with -// the same contract as sort.Search and a boolean flag signifying whether the datum -// was found. It assumes that the DTuple is sorted and panics if it is not. -// -// The target Datum cannot be NULL or a DTuple that contains NULLs (we cannot -// binary search in this case; for example `(1, NULL) IN ((1, 2), ..)` needs to -// be -func (d *DTuple) SearchSorted(ctx *EvalContext, target Datum) (int, bool) { - d.AssertSorted() - if target == DNull { - panic(errors.AssertionFailedf("NULL target (d: %s)", d)) - } - if t, ok := target.(*DTuple); ok && t.ContainsNull() { - panic(errors.AssertionFailedf("target containing NULLs: %#v (d: %s)", target, d)) - } - i := sort.Search(len(d.D), func(i int) bool { - return d.D[i].Compare(ctx, target) >= 0 - }) - found := i < len(d.D) && d.D[i].Compare(ctx, target) == 0 - return i, found -} - -// Normalize sorts and uniques the datum tuple. -func (d *DTuple) Normalize(ctx *EvalContext) { - d.sort(ctx) - d.makeUnique(ctx) -} - -func (d *DTuple) sort(ctx *EvalContext) { - if !d.sorted { - sort.Slice(d.D, func(i, j int) bool { - return d.D[i].Compare(ctx, d.D[j]) < 0 - }) - d.SetSorted() - } -} - -func (d *DTuple) makeUnique(ctx *EvalContext) { - n := 0 - for i := 0; i < len(d.D); i++ { - if n == 0 || d.D[n-1].Compare(ctx, d.D[i]) < 0 { - d.D[n] = d.D[i] - n++ - } - } - d.D = d.D[:n] -} - // Size implements the Datum interface. func (d *DTuple) Size() uintptr { sz := unsafe.Sizeof(*d) @@ -3639,44 +1607,6 @@ func (dNull) ResolvedType() *types.T { return types.Unknown } -// Compare implements the Datum interface. -func (d dNull) Compare(ctx *EvalContext, other Datum) int { - if other == DNull { - return 0 - } - return -1 -} - -// Prev implements the Datum interface. -func (d dNull) Prev(_ *EvalContext) (Datum, bool) { - return nil, false -} - -// Next implements the Datum interface. -func (d dNull) Next(_ *EvalContext) (Datum, bool) { - return nil, false -} - -// IsMax implements the Datum interface. -func (dNull) IsMax(_ *EvalContext) bool { - return true -} - -// IsMin implements the Datum interface. -func (dNull) IsMin(_ *EvalContext) bool { - return true -} - -// Max implements the Datum interface. -func (dNull) Max(_ *EvalContext) (Datum, bool) { - return DNull, true -} - -// Min implements the Datum interface. -func (dNull) Min(_ *EvalContext) (Datum, bool) { - return DNull, true -} - // AmbiguousFormat implements the Datum interface. func (dNull) AmbiguousFormat() bool { return false } @@ -3772,68 +1702,6 @@ func (d *DArray) FirstIndex() int { return 1 } -// Compare implements the Datum interface. -func (d *DArray) Compare(ctx *EvalContext, other Datum) int { - if other == DNull { - // NULL is less than any non-NULL value. - return 1 - } - v, ok := UnwrapDatum(ctx, other).(*DArray) - if !ok { - panic(makeUnsupportedComparisonMessage(d, other)) - } - n := d.Len() - if n > v.Len() { - n = v.Len() - } - for i := 0; i < n; i++ { - c := d.Array[i].Compare(ctx, v.Array[i]) - if c != 0 { - return c - } - } - if d.Len() < v.Len() { - return -1 - } - if d.Len() > v.Len() { - return 1 - } - return 0 -} - -// Prev implements the Datum interface. -func (d *DArray) Prev(_ *EvalContext) (Datum, bool) { - return nil, false -} - -// Next implements the Datum interface. -func (d *DArray) Next(_ *EvalContext) (Datum, bool) { - a := DArray{ParamTyp: d.ParamTyp, Array: make(Datums, d.Len()+1)} - copy(a.Array, d.Array) - a.Array[len(a.Array)-1] = DNull - return &a, true -} - -// Max implements the Datum interface. -func (d *DArray) Max(_ *EvalContext) (Datum, bool) { - return nil, false -} - -// Min implements the Datum interface. -func (d *DArray) Min(_ *EvalContext) (Datum, bool) { - return &DArray{ParamTyp: d.ParamTyp}, true -} - -// IsMax implements the Datum interface. -func (d *DArray) IsMax(_ *EvalContext) bool { - return false -} - -// IsMin implements the Datum interface. -func (d *DArray) IsMin(_ *EvalContext) bool { - return d.Len() == 0 -} - // AmbiguousFormat implements the Datum interface. func (d *DArray) AmbiguousFormat() bool { // The type of the array is ambiguous if it is empty or all-null; when @@ -4016,19 +1884,6 @@ func MakeDEnumFromLogicalRepresentation(typ *types.T, rep string) (*DEnum, error }, nil } -// MakeAllDEnumsInType generates a slice of all values in an enum. -func MakeAllDEnumsInType(typ *types.T) []Datum { - result := make([]Datum, len(typ.TypeMeta.EnumData.LogicalRepresentations)) - for i := 0; i < len(result); i++ { - result[i] = &DEnum{ - EnumTyp: typ, - PhysicalRep: typ.TypeMeta.EnumData.PhysicalRepresentations[i], - LogicalRep: typ.TypeMeta.EnumData.LogicalRepresentations[i], - } - } - return result -} - // Format implements the NodeFormatter interface. func (d *DEnum) Format(ctx *FmtCtx) { if ctx.HasFlags(fmtStaticallyFormatUserDefinedTypes) { @@ -4058,98 +1913,6 @@ func (d *DEnum) ResolvedType() *types.T { return d.EnumTyp } -// Compare implements the Datum interface. -func (d *DEnum) Compare(ctx *EvalContext, other Datum) int { - if other == DNull { - return 1 - } - v, ok := UnwrapDatum(ctx, other).(*DEnum) - if !ok { - panic(makeUnsupportedComparisonMessage(d, other)) - } - return bytes.Compare(d.PhysicalRep, v.PhysicalRep) -} - -// Prev implements the Datum interface. -func (d *DEnum) Prev(ctx *EvalContext) (Datum, bool) { - idx, err := d.EnumTyp.EnumGetIdxOfPhysical(d.PhysicalRep) - if err != nil { - panic(err) - } - if idx == 0 { - return nil, false - } - enumData := d.EnumTyp.TypeMeta.EnumData - return &DEnum{ - EnumTyp: d.EnumTyp, - PhysicalRep: enumData.PhysicalRepresentations[idx-1], - LogicalRep: enumData.LogicalRepresentations[idx-1], - }, true -} - -// Next implements the Datum interface. -func (d *DEnum) Next(ctx *EvalContext) (Datum, bool) { - idx, err := d.EnumTyp.EnumGetIdxOfPhysical(d.PhysicalRep) - if err != nil { - panic(err) - } - enumData := d.EnumTyp.TypeMeta.EnumData - if idx == len(enumData.PhysicalRepresentations)-1 { - return nil, false - } - return &DEnum{ - EnumTyp: d.EnumTyp, - PhysicalRep: enumData.PhysicalRepresentations[idx+1], - LogicalRep: enumData.LogicalRepresentations[idx+1], - }, true -} - -// Max implements the Datum interface. -func (d *DEnum) Max(ctx *EvalContext) (Datum, bool) { - enumData := d.EnumTyp.TypeMeta.EnumData - if len(enumData.PhysicalRepresentations) == 0 { - return nil, false - } - idx := len(enumData.PhysicalRepresentations) - 1 - return &DEnum{ - EnumTyp: d.EnumTyp, - PhysicalRep: enumData.PhysicalRepresentations[idx], - LogicalRep: enumData.LogicalRepresentations[idx], - }, true -} - -// Min implements the Datum interface. -func (d *DEnum) Min(ctx *EvalContext) (Datum, bool) { - enumData := d.EnumTyp.TypeMeta.EnumData - if len(enumData.PhysicalRepresentations) == 0 { - return nil, false - } - return &DEnum{ - EnumTyp: d.EnumTyp, - PhysicalRep: enumData.PhysicalRepresentations[0], - LogicalRep: enumData.LogicalRepresentations[0], - }, true -} - -// IsMax implements the Datum interface. -func (d *DEnum) IsMax(_ *EvalContext) bool { - physReps := d.EnumTyp.TypeMeta.EnumData.PhysicalRepresentations - idx, err := d.EnumTyp.EnumGetIdxOfPhysical(d.PhysicalRep) - if err != nil { - panic(err) - } - return idx == len(physReps)-1 -} - -// IsMin implements the Datum interface. -func (d *DEnum) IsMin(_ *EvalContext) bool { - idx, err := d.EnumTyp.EnumGetIdxOfPhysical(d.PhysicalRep) - if err != nil { - panic(err) - } - return idx == 0 -} - // AmbiguousFormat implements the Datum interface. func (d *DEnum) AmbiguousFormat() bool { return true @@ -4214,40 +1977,6 @@ func NewDOid(d DInt) *DOid { return &oid } -// AsDOid attempts to retrieve a DOid from an Expr, returning a DOid and -// a flag signifying whether the assertion was successful. The function should -// be used instead of direct type assertions wherever a *DOid wrapped by a -// *DOidWrapper is possible. -func AsDOid(e Expr) (*DOid, bool) { - switch t := e.(type) { - case *DOid: - return t, true - case *DOidWrapper: - return AsDOid(t.Wrapped) - } - return NewDOid(0), false -} - -// MustBeDOid attempts to retrieve a DOid from an Expr, panicking if the -// assertion fails. -func MustBeDOid(e Expr) *DOid { - i, ok := AsDOid(e) - if !ok { - panic(errors.AssertionFailedf("expected *DOid, found %T", e)) - } - return i -} - -// NewDOidWithName is a helper routine to create a *DOid initialized from a DInt -// and a string. -func NewDOidWithName(d DInt, typ *types.T, name string) *DOid { - return &DOid{ - DInt: d, - semanticType: typ, - name: name, - } -} - // AsRegProc changes the input DOid into a regproc with the given name and // returns it. func (d *DOid) AsRegProc(name string) *DOid { @@ -4259,31 +1988,6 @@ func (d *DOid) AsRegProc(name string) *DOid { // AmbiguousFormat implements the Datum interface. func (*DOid) AmbiguousFormat() bool { return true } -// Compare implements the Datum interface. -func (d *DOid) Compare(ctx *EvalContext, other Datum) int { - if other == DNull { - // NULL is less than any non-NULL value. - return 1 - } - var v DInt - switch t := UnwrapDatum(ctx, other).(type) { - case *DOid: - v = t.DInt - case *DInt: - v = *t - default: - panic(makeUnsupportedComparisonMessage(d, other)) - } - - if d.DInt < v { - return -1 - } - if d.DInt > v { - return 1 - } - return 0 -} - // Format implements the Datum interface. func (d *DOid) Format(ctx *FmtCtx) { if d.semanticType.Oid() == oid.T_oid || d.name == "" { @@ -4308,24 +2012,6 @@ func (d *DOid) Format(ctx *FmtCtx) { } } -// IsMax implements the Datum interface. -func (d *DOid) IsMax(ctx *EvalContext) bool { return d.DInt.IsMax(ctx) } - -// IsMin implements the Datum interface. -func (d *DOid) IsMin(ctx *EvalContext) bool { return d.DInt.IsMin(ctx) } - -// Next implements the Datum interface. -func (d *DOid) Next(ctx *EvalContext) (Datum, bool) { - next, ok := d.DInt.Next(ctx) - return &DOid{*next.(*DInt), d.semanticType, ""}, ok -} - -// Prev implements the Datum interface. -func (d *DOid) Prev(ctx *EvalContext) (Datum, bool) { - prev, ok := d.DInt.Prev(ctx) - return &DOid{*prev.(*DInt), d.semanticType, ""}, ok -} - // ResolvedType implements the Datum interface. func (d *DOid) ResolvedType() *types.T { return d.semanticType @@ -4334,18 +2020,6 @@ func (d *DOid) ResolvedType() *types.T { // Size implements the Datum interface. func (d *DOid) Size() uintptr { return unsafe.Sizeof(*d) } -// Max implements the Datum interface. -func (d *DOid) Max(ctx *EvalContext) (Datum, bool) { - max, ok := d.DInt.Max(ctx) - return &DOid{*max.(*DInt), d.semanticType, ""}, ok -} - -// Min implements the Datum interface. -func (d *DOid) Min(ctx *EvalContext) (Datum, bool) { - min, ok := d.DInt.Min(ctx) - return &DOid{*min.(*DInt), d.semanticType, ""}, ok -} - // DOidWrapper is a Datum implementation which is a wrapper around a Datum, allowing // custom Oid values to be attached to the Datum and its types.T. // The reason the Datum type was introduced was to permit the introduction of Datum @@ -4394,20 +2068,10 @@ func wrapWithOid(d Datum, oid oid.Oid) Datum { // UnwrapDatum returns the base Datum type for a provided datum, stripping // an *DOidWrapper if present. This is useful for cases like type switches, // where type aliases should be ignored. -func UnwrapDatum(evalCtx *EvalContext, d Datum) Datum { +func UnwrapDatum(d Datum) Datum { if w, ok := d.(*DOidWrapper); ok { return w.Wrapped } - if p, ok := d.(*Placeholder); ok && evalCtx != nil && evalCtx.HasPlaceholders() { - ret, err := p.Eval(evalCtx) - if err != nil { - // If we fail to evaluate the placeholder, it's because we don't have - // a placeholder available. Just return the placeholder and someone else - // will handle this problem. - return d - } - return ret - } return d } @@ -4416,52 +2080,6 @@ func (d *DOidWrapper) ResolvedType() *types.T { return types.OidToType[d.Oid] } -// Compare implements the Datum interface. -func (d *DOidWrapper) Compare(ctx *EvalContext, other Datum) int { - if other == DNull { - // NULL is less than any non-NULL value. - return 1 - } - if v, ok := other.(*DOidWrapper); ok { - return d.Wrapped.Compare(ctx, v.Wrapped) - } - return d.Wrapped.Compare(ctx, other) -} - -// Prev implements the Datum interface. -func (d *DOidWrapper) Prev(ctx *EvalContext) (Datum, bool) { - prev, ok := d.Wrapped.Prev(ctx) - return wrapWithOid(prev, d.Oid), ok -} - -// Next implements the Datum interface. -func (d *DOidWrapper) Next(ctx *EvalContext) (Datum, bool) { - next, ok := d.Wrapped.Next(ctx) - return wrapWithOid(next, d.Oid), ok -} - -// IsMax implements the Datum interface. -func (d *DOidWrapper) IsMax(ctx *EvalContext) bool { - return d.Wrapped.IsMax(ctx) -} - -// IsMin implements the Datum interface. -func (d *DOidWrapper) IsMin(ctx *EvalContext) bool { - return d.Wrapped.IsMin(ctx) -} - -// Max implements the Datum interface. -func (d *DOidWrapper) Max(ctx *EvalContext) (Datum, bool) { - max, ok := d.Wrapped.Max(ctx) - return wrapWithOid(max, d.Oid), ok -} - -// Min implements the Datum interface. -func (d *DOidWrapper) Min(ctx *EvalContext) (Datum, bool) { - min, ok := d.Wrapped.Min(ctx) - return wrapWithOid(min, d.Oid), ok -} - // AmbiguousFormat implements the Datum interface. func (d *DOidWrapper) AmbiguousFormat() bool { return d.Wrapped.AmbiguousFormat() @@ -4483,53 +2101,6 @@ func (d *Placeholder) AmbiguousFormat() bool { return true } -func (d *Placeholder) mustGetValue(ctx *EvalContext) Datum { - e, ok := ctx.Placeholders.Value(d.Idx) - if !ok { - panic(errors.AssertionFailedf("fail")) - } - out, err := e.Eval(ctx) - if err != nil { - panic(errors.NewAssertionErrorWithWrappedErrf(err, "fail")) - } - return out -} - -// Compare implements the Datum interface. -func (d *Placeholder) Compare(ctx *EvalContext, other Datum) int { - return d.mustGetValue(ctx).Compare(ctx, other) -} - -// Prev implements the Datum interface. -func (d *Placeholder) Prev(ctx *EvalContext) (Datum, bool) { - return d.mustGetValue(ctx).Prev(ctx) -} - -// IsMin implements the Datum interface. -func (d *Placeholder) IsMin(ctx *EvalContext) bool { - return d.mustGetValue(ctx).IsMin(ctx) -} - -// Next implements the Datum interface. -func (d *Placeholder) Next(ctx *EvalContext) (Datum, bool) { - return d.mustGetValue(ctx).Next(ctx) -} - -// IsMax implements the Datum interface. -func (d *Placeholder) IsMax(ctx *EvalContext) bool { - return d.mustGetValue(ctx).IsMax(ctx) -} - -// Max implements the Datum interface. -func (d *Placeholder) Max(ctx *EvalContext) (Datum, bool) { - return d.mustGetValue(ctx).Max(ctx) -} - -// Min implements the Datum interface. -func (d *Placeholder) Min(ctx *EvalContext) (Datum, bool) { - return d.mustGetValue(ctx).Min(ctx) -} - // Size implements the Datum interface. func (d *Placeholder) Size() uintptr { panic(errors.AssertionFailedf("shouldn't get called")) @@ -4540,178 +2111,3 @@ func (d *Placeholder) Size() uintptr { func NewDNameFromDString(d *DString) Datum { return wrapWithOid(d, oid.T_name) } - -// NewDName is a helper routine to create a *DName (implemented as a *DOidWrapper) -// initialized from a string. -func NewDName(d string) Datum { - return NewDNameFromDString(NewDString(d)) -} - -// NewDIntVectorFromDArray is a helper routine to create a *DIntVector -// (implemented as a *DOidWrapper) initialized from an existing *DArray. -func NewDIntVectorFromDArray(d *DArray) Datum { - ret := new(DArray) - *ret = *d - ret.customOid = oid.T_int2vector - return ret -} - -// NewDOidVectorFromDArray is a helper routine to create a *DOidVector -// (implemented as a *DOidWrapper) initialized from an existing *DArray. -func NewDOidVectorFromDArray(d *DArray) Datum { - ret := new(DArray) - *ret = *d - ret.customOid = oid.T_oidvector - return ret -} - -// NewDefaultDatum returns a default non-NULL datum value for the given type. -// This is used when updating non-NULL columns that are being added or dropped -// from a table, and there is no user-defined DEFAULT value available. -func NewDefaultDatum(evalCtx *EvalContext, t *types.T) (d Datum, err error) { - switch t.Family() { - case types.BoolFamily: - return DBoolFalse, nil - case types.IntFamily: - return DZero, nil - case types.FloatFamily: - return dZeroFloat, nil - case types.DecimalFamily: - return dZeroDecimal, nil - case types.DateFamily: - return dEpochDate, nil - case types.TimestampFamily: - return dZeroTimestamp, nil - case types.IntervalFamily: - return dZeroInterval, nil - case types.StringFamily: - return dEmptyString, nil - case types.BytesFamily: - return dEmptyBytes, nil - case types.TimestampTZFamily: - return dZeroTimestampTZ, nil - case types.CollatedStringFamily: - return NewDCollatedString("", t.Locale(), &evalCtx.CollationEnv) - case types.OidFamily: - return NewDOidWithName(DInt(t.Oid()), t, t.SQLStandardName()), nil - case types.UnknownFamily: - return DNull, nil - case types.UuidFamily: - return DMinUUID, nil - case types.ArrayFamily: - return NewDArray(t.ArrayContents()), nil - case types.INetFamily: - return DMinIPAddr, nil - case types.TimeFamily: - return dTimeMin, nil - case types.JsonFamily: - return dNullJSON, nil - case types.TimeTZFamily: - return dZeroTimeTZ, nil - case types.GeometryFamily, types.GeographyFamily, types.Box2DFamily: - // TODO(otan): force Geometry/Geography to not allow `NOT NULL` columns to - // make this impossible. - return nil, pgerror.Newf( - pgcode.FeatureNotSupported, - "%s must be set or be NULL", - t.Name(), - ) - case types.TupleFamily: - contents := t.TupleContents() - datums := make([]Datum, len(contents)) - for i, subT := range contents { - datums[i], err = NewDefaultDatum(evalCtx, subT) - if err != nil { - return nil, err - } - } - return NewDTuple(t, datums...), nil - case types.BitFamily: - return bitArrayZero, nil - default: - return nil, errors.AssertionFailedf("unhandled type %v", t.SQLString()) - } -} - -// DatumTypeSize returns a lower bound on the total size of a Datum -// of the given type in bytes, including memory that is -// pointed at (even if shared between Datum instances) but excluding -// allocation overhead. -// -// The second return value indicates whether data of this type have different -// sizes. -// -// It holds for every Datum d that d.Size() >= DatumSize(d.ResolvedType()) -func DatumTypeSize(t *types.T) (size uintptr, isVarlen bool) { - // The following are composite types or types that support multiple widths. - switch t.Family() { - case types.TupleFamily: - if types.IsWildcardTupleType(t) { - return uintptr(0), false - } - sz := uintptr(0) - variable := false - for i := range t.TupleContents() { - typsz, typvariable := DatumTypeSize(t.TupleContents()[i]) - sz += typsz - variable = variable || typvariable - } - return sz, variable - case types.IntFamily, types.FloatFamily: - return uintptr(t.Width() / 8), false - - case types.StringFamily: - // T_char is a special string type that has a fixed size of 1. We have to - // report its size accurately, and that it's not a variable-length datatype. - if t.Oid() == oid.T_char { - return 1, false - } - } - - // All the primary types have fixed size information. - if bSzInfo, ok := baseDatumTypeSizes[t.Family()]; ok { - return bSzInfo.sz, bSzInfo.variable - } - - panic(errors.AssertionFailedf("unknown type: %T", t)) -} - -const ( - fixedSize = false - variableSize = true -) - -var baseDatumTypeSizes = map[types.Family]struct { - sz uintptr - variable bool -}{ - types.UnknownFamily: {unsafe.Sizeof(dNull{}), fixedSize}, - types.BoolFamily: {unsafe.Sizeof(DBool(false)), fixedSize}, - types.Box2DFamily: {unsafe.Sizeof(DBox2D{CartesianBoundingBox: geo.CartesianBoundingBox{}}), fixedSize}, - types.BitFamily: {unsafe.Sizeof(DBitArray{}), variableSize}, - types.IntFamily: {unsafe.Sizeof(DInt(0)), fixedSize}, - types.FloatFamily: {unsafe.Sizeof(DFloat(0.0)), fixedSize}, - types.DecimalFamily: {unsafe.Sizeof(DDecimal{}), variableSize}, - types.StringFamily: {unsafe.Sizeof(DString("")), variableSize}, - types.CollatedStringFamily: {unsafe.Sizeof(DCollatedString{"", "", nil}), variableSize}, - types.BytesFamily: {unsafe.Sizeof(DBytes("")), variableSize}, - types.DateFamily: {unsafe.Sizeof(DDate{}), fixedSize}, - types.GeographyFamily: {unsafe.Sizeof(DGeography{}), variableSize}, - types.GeometryFamily: {unsafe.Sizeof(DGeometry{}), variableSize}, - types.TimeFamily: {unsafe.Sizeof(DTime(0)), fixedSize}, - types.TimeTZFamily: {unsafe.Sizeof(DTimeTZ{}), fixedSize}, - types.TimestampFamily: {unsafe.Sizeof(DTimestamp{}), fixedSize}, - types.TimestampTZFamily: {unsafe.Sizeof(DTimestampTZ{}), fixedSize}, - types.IntervalFamily: {unsafe.Sizeof(DInterval{}), fixedSize}, - types.JsonFamily: {unsafe.Sizeof(DJSON{}), variableSize}, - types.UuidFamily: {unsafe.Sizeof(DUuid{}), fixedSize}, - types.INetFamily: {unsafe.Sizeof(DIPAddr{}), fixedSize}, - types.OidFamily: {unsafe.Sizeof(DInt(0)), fixedSize}, - types.EnumFamily: {unsafe.Sizeof(DEnum{}), variableSize}, - - // TODO(jordan,justin): This seems suspicious. - types.ArrayFamily: {unsafe.Sizeof(DString("")), variableSize}, - - // TODO(jordan,justin): This seems suspicious. - types.AnyFamily: {unsafe.Sizeof(DString("")), variableSize}, -} diff --git a/postgres/parser/sem/tree/eval.go b/postgres/parser/sem/tree/eval.go index 8544ff88c0..416ec54a39 100644 --- a/postgres/parser/sem/tree/eval.go +++ b/postgres/parser/sem/tree/eval.go @@ -27,29 +27,17 @@ package tree import ( "context" "fmt" - "math" "math/big" "regexp" - "strings" - "time" - "unicode/utf8" - "github.com/cockroachdb/apd/v2" "github.com/cockroachdb/errors" "github.com/lib/pq/oid" - "github.com/dolthub/doltgresql/postgres/parser/duration" "github.com/dolthub/doltgresql/postgres/parser/geo" - "github.com/dolthub/doltgresql/postgres/parser/hlc" - "github.com/dolthub/doltgresql/postgres/parser/json" - "github.com/dolthub/doltgresql/postgres/parser/kv" "github.com/dolthub/doltgresql/postgres/parser/pgcode" "github.com/dolthub/doltgresql/postgres/parser/pgerror" "github.com/dolthub/doltgresql/postgres/parser/roleoption" - "github.com/dolthub/doltgresql/postgres/parser/sessiondata" - "github.com/dolthub/doltgresql/postgres/parser/timeofday" "github.com/dolthub/doltgresql/postgres/parser/types" - "github.com/dolthub/doltgresql/postgres/parser/utils" ) var ( @@ -70,18 +58,10 @@ var ( big10E10 = big.NewInt(1e10) ) -// NewCannotMixBitArraySizesError creates an error for the case when a bitwise -// aggregate function is called on bit arrays with different sizes. -func NewCannotMixBitArraySizesError(op string) error { - return pgerror.Newf(pgcode.StringDataLengthMismatch, - "cannot %s bit strings of different sizes", op) -} - // UnaryOp is a unary operator. type UnaryOp struct { Typ *types.T ReturnType *types.T - Fn func(*EvalContext, Datum) (Datum, error) Volatility Volatility types TypeList @@ -121,44 +101,21 @@ var UnaryOps = unaryOpFixups(map[UnaryOperator]unaryOpOverload{ &UnaryOp{ Typ: types.Int, ReturnType: types.Int, - Fn: func(_ *EvalContext, d Datum) (Datum, error) { - i := MustBeDInt(d) - if i == math.MinInt64 { - return nil, ErrIntOutOfRange - } - return NewDInt(-i), nil - }, Volatility: VolatilityImmutable, }, &UnaryOp{ Typ: types.Float, ReturnType: types.Float, - Fn: func(_ *EvalContext, d Datum) (Datum, error) { - return NewDFloat(-*d.(*DFloat)), nil - }, Volatility: VolatilityImmutable, }, &UnaryOp{ Typ: types.Decimal, ReturnType: types.Decimal, - Fn: func(_ *EvalContext, d Datum) (Datum, error) { - dec := &d.(*DDecimal).Decimal - dd := &DDecimal{} - dd.Decimal.Neg(dec) - return dd, nil - }, Volatility: VolatilityImmutable, }, &UnaryOp{ Typ: types.Interval, ReturnType: types.Interval, - Fn: func(_ *EvalContext, d Datum) (Datum, error) { - i := d.(*DInterval).Duration - i.SetNanos(-i.Nanos()) - i.Days = -i.Days - i.Months = -i.Months - return &DInterval{Duration: i}, nil - }, Volatility: VolatilityImmutable, }, }, @@ -167,27 +124,16 @@ var UnaryOps = unaryOpFixups(map[UnaryOperator]unaryOpOverload{ &UnaryOp{ Typ: types.Int, ReturnType: types.Int, - Fn: func(_ *EvalContext, d Datum) (Datum, error) { - return NewDInt(^MustBeDInt(d)), nil - }, Volatility: VolatilityImmutable, }, &UnaryOp{ Typ: types.VarBit, ReturnType: types.VarBit, - Fn: func(_ *EvalContext, d Datum) (Datum, error) { - p := MustBeDBitArray(d) - return &DBitArray{BitArray: utils.Not(p.BitArray)}, nil - }, Volatility: VolatilityImmutable, }, &UnaryOp{ Typ: types.INet, ReturnType: types.INet, - Fn: func(_ *EvalContext, d Datum) (Datum, error) { - ipAddr := MustBeDIPAddr(d).IPAddr - return NewDIPAddr(DIPAddr{ipAddr.Complement()}), nil - }, Volatility: VolatilityImmutable, }, }, @@ -196,18 +142,11 @@ var UnaryOps = unaryOpFixups(map[UnaryOperator]unaryOpOverload{ &UnaryOp{ Typ: types.Float, ReturnType: types.Float, - Fn: func(_ *EvalContext, d Datum) (Datum, error) { - return Sqrt(float64(*d.(*DFloat))) - }, Volatility: VolatilityImmutable, }, &UnaryOp{ Typ: types.Decimal, ReturnType: types.Decimal, - Fn: func(_ *EvalContext, d Datum) (Datum, error) { - dec := &d.(*DDecimal).Decimal - return DecimalSqrt(dec) - }, Volatility: VolatilityImmutable, }, }, @@ -216,33 +155,22 @@ var UnaryOps = unaryOpFixups(map[UnaryOperator]unaryOpOverload{ &UnaryOp{ Typ: types.Float, ReturnType: types.Float, - Fn: func(_ *EvalContext, d Datum) (Datum, error) { - return Cbrt(float64(*d.(*DFloat))) - }, Volatility: VolatilityImmutable, }, &UnaryOp{ Typ: types.Decimal, ReturnType: types.Decimal, - Fn: func(_ *EvalContext, d Datum) (Datum, error) { - dec := &d.(*DDecimal).Decimal - return DecimalCbrt(dec) - }, Volatility: VolatilityImmutable, }, }, }) -// TwoArgFn is a function that operates on two Datum arguments. -type TwoArgFn func(*EvalContext, Datum, Datum) (Datum, error) - // BinOp is a binary operator. type BinOp struct { LeftType *types.T RightType *types.T ReturnType *types.T NullableArgs bool - Fn TwoArgFn Volatility Volatility types TypeList @@ -265,40 +193,6 @@ func (*BinOp) preferred() bool { return false } -// AppendToMaybeNullArray appends an element to an array. If the first -// argument is NULL, an array of one element is created. -func AppendToMaybeNullArray(typ *types.T, left Datum, right Datum) (Datum, error) { - result := NewDArray(typ) - if left != DNull { - for _, e := range MustBeDArray(left).Array { - if err := result.Append(e); err != nil { - return nil, err - } - } - } - if err := result.Append(right); err != nil { - return nil, err - } - return result, nil -} - -// PrependToMaybeNullArray prepends an element in the front of an arrray. -// If the argument is NULL, an array of one element is created. -func PrependToMaybeNullArray(typ *types.T, left Datum, right Datum) (Datum, error) { - result := NewDArray(typ) - if err := result.Append(left); err != nil { - return nil, err - } - if right != DNull { - for _, e := range MustBeDArray(right).Array { - if err := result.Append(e); err != nil { - return nil, err - } - } - } - return result, nil -} - // TODO(justin): these might be improved by making arrays into an interface and // then introducing a ConcatenatedArray implementation which just references two // existing arrays. This would optimize the common case of appending an element @@ -311,10 +205,7 @@ func initArrayElementConcatenation() { RightType: typ, ReturnType: types.MakeArray(typ), NullableArgs: true, - Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) { - return AppendToMaybeNullArray(typ, left, right) - }, - Volatility: VolatilityImmutable, + Volatility: VolatilityImmutable, }) BinOps[Concat] = append(BinOps[Concat], &BinOp{ @@ -322,61 +213,11 @@ func initArrayElementConcatenation() { RightType: types.MakeArray(typ), ReturnType: types.MakeArray(typ), NullableArgs: true, - Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) { - return PrependToMaybeNullArray(typ, left, right) - }, - Volatility: VolatilityImmutable, + Volatility: VolatilityImmutable, }) } } -// ConcatArrays concatenates two arrays. -func ConcatArrays(typ *types.T, left Datum, right Datum) (Datum, error) { - if left == DNull && right == DNull { - return DNull, nil - } - result := NewDArray(typ) - if left != DNull { - for _, e := range MustBeDArray(left).Array { - if err := result.Append(e); err != nil { - return nil, err - } - } - } - if right != DNull { - for _, e := range MustBeDArray(right).Array { - if err := result.Append(e); err != nil { - return nil, err - } - } - } - return result, nil -} - -// ArrayContains return true if the haystack contains all needles. -func ArrayContains(ctx *EvalContext, haystack *DArray, needles *DArray) (*DBool, error) { - if !haystack.ParamTyp.Equivalent(needles.ParamTyp) { - return DBoolFalse, pgerror.New(pgcode.DatatypeMismatch, "cannot compare arrays with different element types") - } - for _, needle := range needles.Array { - // Nulls don't compare to each other in @> syntax. - if needle == DNull { - return DBoolFalse, nil - } - var found bool - for _, hay := range haystack.Array { - if needle.Compare(ctx, hay) == 0 { - found = true - break - } - } - if !found { - return DBoolFalse, nil - } - } - return DBoolTrue, nil -} - func initArrayToArrayConcatenation() { for _, t := range types.Scalar { typ := t @@ -385,10 +226,7 @@ func initArrayToArrayConcatenation() { RightType: types.MakeArray(typ), ReturnType: types.MakeArray(typ), NullableArgs: true, - Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) { - return ConcatArrays(typ, left, right) - }, - Volatility: VolatilityImmutable, + Volatility: VolatilityImmutable, }) } } @@ -422,27 +260,6 @@ func (o binOpOverload) lookupImpl(left, right *types.T) (*BinOp, bool) { return nil, false } -// getJSONPath is used for the #> and #>> operators. -func getJSONPath(j DJSON, ary DArray) (Datum, error) { - // TODO(justin): this is slightly annoying because we have to allocate - // a new array since the JSON package isn't aware of DArray. - path := make([]string, len(ary.Array)) - for i, v := range ary.Array { - if v == DNull { - return DNull, nil - } - path[i] = string(MustBeDString(v)) - } - result, err := json.FetchPath(j.JSON, path) - if err != nil { - return nil, err - } - if result == nil { - return DNull, nil - } - return &DJSON{result}, nil -} - // BinOps contains the binary operations indexed by operation type. var BinOps = map[BinaryOperator]binOpOverload{ Bitand: { @@ -450,37 +267,18 @@ var BinOps = map[BinaryOperator]binOpOverload{ LeftType: types.Int, RightType: types.Int, ReturnType: types.Int, - Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) { - return NewDInt(MustBeDInt(left) & MustBeDInt(right)), nil - }, Volatility: VolatilityImmutable, }, &BinOp{ LeftType: types.VarBit, RightType: types.VarBit, ReturnType: types.VarBit, - Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) { - lhs := MustBeDBitArray(left) - rhs := MustBeDBitArray(right) - if lhs.BitLen() != rhs.BitLen() { - return nil, NewCannotMixBitArraySizesError("AND") - } - return &DBitArray{ - BitArray: utils.And(lhs.BitArray, rhs.BitArray), - }, nil - }, Volatility: VolatilityImmutable, }, &BinOp{ LeftType: types.INet, RightType: types.INet, ReturnType: types.INet, - Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) { - ipAddr := MustBeDIPAddr(left).IPAddr - other := MustBeDIPAddr(right).IPAddr - newIPAddr, err := ipAddr.And(&other) - return NewDIPAddr(DIPAddr{newIPAddr}), err - }, Volatility: VolatilityImmutable, }, }, @@ -490,37 +288,18 @@ var BinOps = map[BinaryOperator]binOpOverload{ LeftType: types.Int, RightType: types.Int, ReturnType: types.Int, - Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) { - return NewDInt(MustBeDInt(left) | MustBeDInt(right)), nil - }, Volatility: VolatilityImmutable, }, &BinOp{ LeftType: types.VarBit, RightType: types.VarBit, ReturnType: types.VarBit, - Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) { - lhs := MustBeDBitArray(left) - rhs := MustBeDBitArray(right) - if lhs.BitLen() != rhs.BitLen() { - return nil, NewCannotMixBitArraySizesError("OR") - } - return &DBitArray{ - BitArray: utils.Or(lhs.BitArray, rhs.BitArray), - }, nil - }, Volatility: VolatilityImmutable, }, &BinOp{ LeftType: types.INet, RightType: types.INet, ReturnType: types.INet, - Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) { - ipAddr := MustBeDIPAddr(left).IPAddr - other := MustBeDIPAddr(right).IPAddr - newIPAddr, err := ipAddr.Or(&other) - return NewDIPAddr(DIPAddr{newIPAddr}), err - }, Volatility: VolatilityImmutable, }, }, @@ -530,25 +309,12 @@ var BinOps = map[BinaryOperator]binOpOverload{ LeftType: types.Int, RightType: types.Int, ReturnType: types.Int, - Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) { - return NewDInt(MustBeDInt(left) ^ MustBeDInt(right)), nil - }, Volatility: VolatilityImmutable, }, &BinOp{ LeftType: types.VarBit, RightType: types.VarBit, ReturnType: types.VarBit, - Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) { - lhs := MustBeDBitArray(left) - rhs := MustBeDBitArray(right) - if lhs.BitLen() != rhs.BitLen() { - return nil, NewCannotMixBitArraySizesError("XOR") - } - return &DBitArray{ - BitArray: utils.Xor(lhs.BitArray, rhs.BitArray), - }, nil - }, Volatility: VolatilityImmutable, }, }, @@ -558,292 +324,144 @@ var BinOps = map[BinaryOperator]binOpOverload{ LeftType: types.Int, RightType: types.Int, ReturnType: types.Int, - Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) { - a, b := MustBeDInt(left), MustBeDInt(right) - r, ok := utils.AddWithOverflow(int64(a), int64(b)) - if !ok { - return nil, ErrIntOutOfRange - } - return NewDInt(DInt(r)), nil - }, Volatility: VolatilityImmutable, }, &BinOp{ LeftType: types.Float, RightType: types.Float, ReturnType: types.Float, - Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) { - return NewDFloat(*left.(*DFloat) + *right.(*DFloat)), nil - }, Volatility: VolatilityImmutable, }, &BinOp{ LeftType: types.Decimal, RightType: types.Decimal, ReturnType: types.Decimal, - Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) { - l := &left.(*DDecimal).Decimal - r := &right.(*DDecimal).Decimal - dd := &DDecimal{} - _, err := ExactCtx.Add(&dd.Decimal, l, r) - return dd, err - }, Volatility: VolatilityImmutable, }, &BinOp{ LeftType: types.Decimal, RightType: types.Int, ReturnType: types.Decimal, - Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) { - l := &left.(*DDecimal).Decimal - r := MustBeDInt(right) - dd := &DDecimal{} - dd.SetInt64(int64(r)) - _, err := ExactCtx.Add(&dd.Decimal, l, &dd.Decimal) - return dd, err - }, Volatility: VolatilityImmutable, }, &BinOp{ LeftType: types.Int, RightType: types.Decimal, ReturnType: types.Decimal, - Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) { - l := MustBeDInt(left) - r := &right.(*DDecimal).Decimal - dd := &DDecimal{} - dd.SetInt64(int64(l)) - _, err := ExactCtx.Add(&dd.Decimal, &dd.Decimal, r) - return dd, err - }, Volatility: VolatilityImmutable, }, &BinOp{ LeftType: types.Date, RightType: types.Int, ReturnType: types.Date, - Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) { - d, err := left.(*DDate).AddDays(int64(MustBeDInt(right))) - if err != nil { - return nil, err - } - return NewDDate(d), nil - }, Volatility: VolatilityImmutable, }, &BinOp{ LeftType: types.Int, RightType: types.Date, ReturnType: types.Date, - Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) { - d, err := right.(*DDate).AddDays(int64(MustBeDInt(left))) - if err != nil { - return nil, err - } - return NewDDate(d), nil - - }, Volatility: VolatilityImmutable, }, &BinOp{ LeftType: types.Date, RightType: types.Time, ReturnType: types.Timestamp, - Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) { - leftTime, err := left.(*DDate).ToTime() - if err != nil { - return nil, err - } - t := time.Duration(*right.(*DTime)) * time.Microsecond - return MakeDTimestamp(leftTime.Add(t), time.Microsecond) - }, Volatility: VolatilityImmutable, }, &BinOp{ LeftType: types.Time, RightType: types.Date, ReturnType: types.Timestamp, - Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) { - rightTime, err := right.(*DDate).ToTime() - if err != nil { - return nil, err - } - t := time.Duration(*left.(*DTime)) * time.Microsecond - return MakeDTimestamp(rightTime.Add(t), time.Microsecond) - }, Volatility: VolatilityImmutable, }, &BinOp{ LeftType: types.Date, RightType: types.TimeTZ, ReturnType: types.TimestampTZ, - Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) { - leftTime, err := left.(*DDate).ToTime() - if err != nil { - return nil, err - } - t := leftTime.Add(right.(*DTimeTZ).ToDuration()) - return MakeDTimestampTZ(t, time.Microsecond) - }, Volatility: VolatilityImmutable, }, &BinOp{ LeftType: types.TimeTZ, RightType: types.Date, ReturnType: types.TimestampTZ, - Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) { - rightTime, err := right.(*DDate).ToTime() - if err != nil { - return nil, err - } - t := rightTime.Add(left.(*DTimeTZ).ToDuration()) - return MakeDTimestampTZ(t, time.Microsecond) - }, Volatility: VolatilityImmutable, }, &BinOp{ LeftType: types.Time, RightType: types.Interval, ReturnType: types.Time, - Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) { - t := timeofday.TimeOfDay(*left.(*DTime)) - return MakeDTime(t.Add(right.(*DInterval).Duration)), nil - }, Volatility: VolatilityImmutable, }, &BinOp{ LeftType: types.Interval, RightType: types.Time, ReturnType: types.Time, - Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) { - t := timeofday.TimeOfDay(*right.(*DTime)) - return MakeDTime(t.Add(left.(*DInterval).Duration)), nil - }, Volatility: VolatilityImmutable, }, &BinOp{ LeftType: types.TimeTZ, RightType: types.Interval, ReturnType: types.TimeTZ, - Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) { - t := left.(*DTimeTZ) - duration := right.(*DInterval).Duration - return NewDTimeTZFromOffset(t.Add(duration), t.OffsetSecs), nil - }, Volatility: VolatilityImmutable, }, &BinOp{ LeftType: types.Interval, RightType: types.TimeTZ, ReturnType: types.TimeTZ, - Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) { - t := right.(*DTimeTZ) - duration := left.(*DInterval).Duration - return NewDTimeTZFromOffset(t.Add(duration), t.OffsetSecs), nil - }, Volatility: VolatilityImmutable, }, &BinOp{ LeftType: types.Timestamp, RightType: types.Interval, ReturnType: types.Timestamp, - Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) { - return MakeDTimestamp(duration.Add( - left.(*DTimestamp).Time, right.(*DInterval).Duration), time.Microsecond) - }, Volatility: VolatilityImmutable, }, &BinOp{ LeftType: types.Interval, RightType: types.Timestamp, ReturnType: types.Timestamp, - Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) { - return MakeDTimestamp(duration.Add( - right.(*DTimestamp).Time, left.(*DInterval).Duration), time.Microsecond) - }, Volatility: VolatilityImmutable, }, &BinOp{ LeftType: types.TimestampTZ, RightType: types.Interval, ReturnType: types.TimestampTZ, - Fn: func(ctx *EvalContext, left Datum, right Datum) (Datum, error) { - // Convert time to be in the given timezone, as math relies on matching timezones.. - t := duration.Add(left.(*DTimestampTZ).Time.In(ctx.GetLocation()), right.(*DInterval).Duration) - return MakeDTimestampTZ(t, time.Microsecond) - }, Volatility: VolatilityStable, }, &BinOp{ LeftType: types.Interval, RightType: types.TimestampTZ, ReturnType: types.TimestampTZ, - Fn: func(ctx *EvalContext, left Datum, right Datum) (Datum, error) { - // Convert time to be in the given timezone, as math relies on matching timezones.. - t := duration.Add(right.(*DTimestampTZ).Time.In(ctx.GetLocation()), left.(*DInterval).Duration) - return MakeDTimestampTZ(t, time.Microsecond) - }, Volatility: VolatilityStable, }, &BinOp{ LeftType: types.Interval, RightType: types.Interval, ReturnType: types.Interval, - Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) { - return &DInterval{Duration: left.(*DInterval).Duration.Add(right.(*DInterval).Duration)}, nil - }, Volatility: VolatilityImmutable, }, &BinOp{ LeftType: types.Date, RightType: types.Interval, ReturnType: types.Timestamp, - Fn: func(ctx *EvalContext, left Datum, right Datum) (Datum, error) { - leftTime, err := left.(*DDate).ToTime() - if err != nil { - return nil, err - } - t := duration.Add(leftTime, right.(*DInterval).Duration) - return MakeDTimestamp(t, time.Microsecond) - }, Volatility: VolatilityImmutable, }, &BinOp{ LeftType: types.Interval, RightType: types.Date, ReturnType: types.Timestamp, - Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) { - rightTime, err := right.(*DDate).ToTime() - if err != nil { - return nil, err - } - t := duration.Add(rightTime, left.(*DInterval).Duration) - return MakeDTimestamp(t, time.Microsecond) - }, Volatility: VolatilityImmutable, }, &BinOp{ LeftType: types.INet, RightType: types.Int, ReturnType: types.INet, - Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) { - ipAddr := MustBeDIPAddr(left).IPAddr - i := MustBeDInt(right) - newIPAddr, err := ipAddr.Add(int64(i)) - return NewDIPAddr(DIPAddr{newIPAddr}), err - }, Volatility: VolatilityImmutable, }, &BinOp{ LeftType: types.Int, RightType: types.INet, ReturnType: types.INet, - Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) { - i := MustBeDInt(left) - ipAddr := MustBeDIPAddr(right).IPAddr - newIPAddr, err := ipAddr.Add(int64(i)) - return NewDIPAddr(DIPAddr{newIPAddr}), err - }, Volatility: VolatilityImmutable, }, }, @@ -853,298 +471,138 @@ var BinOps = map[BinaryOperator]binOpOverload{ LeftType: types.Int, RightType: types.Int, ReturnType: types.Int, - Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) { - a, b := MustBeDInt(left), MustBeDInt(right) - r, ok := utils.SubWithOverflow(int64(a), int64(b)) - if !ok { - return nil, ErrIntOutOfRange - } - return NewDInt(DInt(r)), nil - }, Volatility: VolatilityImmutable, }, &BinOp{ LeftType: types.Float, RightType: types.Float, ReturnType: types.Float, - Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) { - return NewDFloat(*left.(*DFloat) - *right.(*DFloat)), nil - }, Volatility: VolatilityImmutable, }, &BinOp{ LeftType: types.Decimal, RightType: types.Decimal, ReturnType: types.Decimal, - Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) { - l := &left.(*DDecimal).Decimal - r := &right.(*DDecimal).Decimal - dd := &DDecimal{} - _, err := ExactCtx.Sub(&dd.Decimal, l, r) - return dd, err - }, Volatility: VolatilityImmutable, }, &BinOp{ LeftType: types.Decimal, RightType: types.Int, ReturnType: types.Decimal, - Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) { - l := &left.(*DDecimal).Decimal - r := MustBeDInt(right) - dd := &DDecimal{} - dd.SetInt64(int64(r)) - _, err := ExactCtx.Sub(&dd.Decimal, l, &dd.Decimal) - return dd, err - }, Volatility: VolatilityImmutable, }, &BinOp{ LeftType: types.Int, RightType: types.Decimal, ReturnType: types.Decimal, - Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) { - l := MustBeDInt(left) - r := &right.(*DDecimal).Decimal - dd := &DDecimal{} - dd.SetInt64(int64(l)) - _, err := ExactCtx.Sub(&dd.Decimal, &dd.Decimal, r) - return dd, err - }, Volatility: VolatilityImmutable, }, &BinOp{ LeftType: types.Date, RightType: types.Int, ReturnType: types.Date, - Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) { - d, err := left.(*DDate).SubDays(int64(MustBeDInt(right))) - if err != nil { - return nil, err - } - return NewDDate(d), nil - }, Volatility: VolatilityImmutable, }, &BinOp{ LeftType: types.Date, RightType: types.Date, ReturnType: types.Int, - Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) { - l, r := left.(*DDate).Date, right.(*DDate).Date - if !l.IsFinite() || !r.IsFinite() { - return nil, pgerror.New(pgcode.DatetimeFieldOverflow, "cannot subtract infinite dates") - } - a := l.PGEpochDays() - b := r.PGEpochDays() - // This can't overflow because they are upconverted from int32 to int64. - return NewDInt(DInt(int64(a) - int64(b))), nil - }, Volatility: VolatilityImmutable, }, &BinOp{ LeftType: types.Date, RightType: types.Time, ReturnType: types.Timestamp, - Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) { - leftTime, err := left.(*DDate).ToTime() - if err != nil { - return nil, err - } - t := time.Duration(*right.(*DTime)) * time.Microsecond - return MakeDTimestamp(leftTime.Add(-1*t), time.Microsecond) - }, Volatility: VolatilityImmutable, }, &BinOp{ LeftType: types.Time, RightType: types.Time, ReturnType: types.Interval, - Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) { - t1 := timeofday.TimeOfDay(*left.(*DTime)) - t2 := timeofday.TimeOfDay(*right.(*DTime)) - diff := timeofday.Difference(t1, t2) - return &DInterval{Duration: diff}, nil - }, Volatility: VolatilityImmutable, }, &BinOp{ LeftType: types.Timestamp, RightType: types.Timestamp, ReturnType: types.Interval, - Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) { - nanos := left.(*DTimestamp).Sub(right.(*DTimestamp).Time).Nanoseconds() - return &DInterval{Duration: duration.MakeNormalizedDuration(nanos, 0, 0)}, nil - }, Volatility: VolatilityImmutable, }, &BinOp{ LeftType: types.TimestampTZ, RightType: types.TimestampTZ, ReturnType: types.Interval, - Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) { - nanos := left.(*DTimestampTZ).Sub(right.(*DTimestampTZ).Time).Nanoseconds() - return &DInterval{Duration: duration.MakeNormalizedDuration(nanos, 0, 0)}, nil - }, Volatility: VolatilityImmutable, }, &BinOp{ LeftType: types.Timestamp, RightType: types.TimestampTZ, ReturnType: types.Interval, - Fn: func(ctx *EvalContext, left Datum, right Datum) (Datum, error) { - // These two quantities aren't directly comparable. Convert the - // TimestampTZ to a timestamp first. - stripped, err := right.(*DTimestampTZ).stripTimeZone(ctx) - if err != nil { - return nil, err - } - nanos := left.(*DTimestamp).Sub(stripped.Time).Nanoseconds() - return &DInterval{Duration: duration.MakeNormalizedDuration(nanos, 0, 0)}, nil - }, Volatility: VolatilityStable, }, &BinOp{ LeftType: types.TimestampTZ, RightType: types.Timestamp, ReturnType: types.Interval, - Fn: func(ctx *EvalContext, left Datum, right Datum) (Datum, error) { - // These two quantities aren't directly comparable. Convert the - // TimestampTZ to a timestamp first. - stripped, err := left.(*DTimestampTZ).stripTimeZone(ctx) - if err != nil { - return nil, err - } - nanos := stripped.Sub(right.(*DTimestamp).Time).Nanoseconds() - return &DInterval{Duration: duration.MakeNormalizedDuration(nanos, 0, 0)}, nil - }, Volatility: VolatilityStable, }, &BinOp{ LeftType: types.Time, RightType: types.Interval, ReturnType: types.Time, - Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) { - t := timeofday.TimeOfDay(*left.(*DTime)) - return MakeDTime(t.Add(right.(*DInterval).Duration.Mul(-1))), nil - }, Volatility: VolatilityImmutable, }, &BinOp{ LeftType: types.TimeTZ, RightType: types.Interval, ReturnType: types.TimeTZ, - Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) { - t := left.(*DTimeTZ) - duration := right.(*DInterval).Duration - return NewDTimeTZFromOffset(t.Add(duration.Mul(-1)), t.OffsetSecs), nil - }, Volatility: VolatilityImmutable, }, &BinOp{ LeftType: types.Timestamp, RightType: types.Interval, ReturnType: types.Timestamp, - Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) { - return MakeDTimestamp(duration.Add( - left.(*DTimestamp).Time, right.(*DInterval).Duration.Mul(-1)), time.Microsecond) - }, Volatility: VolatilityImmutable, }, &BinOp{ LeftType: types.TimestampTZ, RightType: types.Interval, ReturnType: types.TimestampTZ, - Fn: func(ctx *EvalContext, left Datum, right Datum) (Datum, error) { - t := duration.Add( - left.(*DTimestampTZ).Time.In(ctx.GetLocation()), - right.(*DInterval).Duration.Mul(-1), - ) - return MakeDTimestampTZ(t, time.Microsecond) - }, Volatility: VolatilityStable, }, &BinOp{ LeftType: types.Date, RightType: types.Interval, ReturnType: types.Timestamp, - Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) { - leftTime, err := left.(*DDate).ToTime() - if err != nil { - return nil, err - } - t := duration.Add(leftTime, right.(*DInterval).Duration.Mul(-1)) - return MakeDTimestamp(t, time.Microsecond) - }, Volatility: VolatilityImmutable, }, &BinOp{ LeftType: types.Interval, RightType: types.Interval, ReturnType: types.Interval, - Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) { - return &DInterval{Duration: left.(*DInterval).Duration.Sub(right.(*DInterval).Duration)}, nil - }, Volatility: VolatilityImmutable, }, &BinOp{ LeftType: types.Jsonb, RightType: types.String, ReturnType: types.Jsonb, - Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) { - j, _, err := left.(*DJSON).JSON.RemoveString(string(MustBeDString(right))) - if err != nil { - return nil, err - } - return &DJSON{j}, nil - }, Volatility: VolatilityImmutable, }, &BinOp{ LeftType: types.Jsonb, RightType: types.Int, ReturnType: types.Jsonb, - Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) { - j, _, err := left.(*DJSON).JSON.RemoveIndex(int(MustBeDInt(right))) - if err != nil { - return nil, err - } - return &DJSON{j}, nil - }, Volatility: VolatilityImmutable, }, &BinOp{ LeftType: types.Jsonb, RightType: types.MakeArray(types.String), ReturnType: types.Jsonb, - Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) { - j := left.(*DJSON).JSON - arr := *MustBeDArray(right) - - for _, str := range arr.Array { - if str == DNull { - continue - } - var err error - j, _, err = j.RemoveString(string(MustBeDString(str))) - if err != nil { - return nil, err - } - } - return &DJSON{j}, nil - }, Volatility: VolatilityImmutable, }, &BinOp{ LeftType: types.INet, RightType: types.INet, ReturnType: types.Int, - Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) { - ipAddr := MustBeDIPAddr(left).IPAddr - other := MustBeDIPAddr(right).IPAddr - diff, err := ipAddr.SubIPAddr(&other) - return NewDInt(DInt(diff)), err - }, Volatility: VolatilityImmutable, }, &BinOp{ @@ -1152,12 +610,6 @@ var BinOps = map[BinaryOperator]binOpOverload{ LeftType: types.INet, RightType: types.Int, ReturnType: types.INet, - Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) { - ipAddr := MustBeDIPAddr(left).IPAddr - i := MustBeDInt(right) - newIPAddr, err := ipAddr.Sub(int64(i)) - return NewDIPAddr(DIPAddr{newIPAddr}), err - }, Volatility: VolatilityImmutable, }, }, @@ -1167,44 +619,18 @@ var BinOps = map[BinaryOperator]binOpOverload{ LeftType: types.Int, RightType: types.Int, ReturnType: types.Int, - Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) { - // See Rob Pike's implementation from - // https://groups.google.com/d/msg/golang-nuts/h5oSN5t3Au4/KaNQREhZh0QJ - - a, b := MustBeDInt(left), MustBeDInt(right) - c := a * b - if a == 0 || b == 0 || a == 1 || b == 1 { - // ignore - } else if a == math.MinInt64 || b == math.MinInt64 { - // This test is required to detect math.MinInt64 * -1. - return nil, ErrIntOutOfRange - } else if c/b != a { - return nil, ErrIntOutOfRange - } - return NewDInt(c), nil - }, Volatility: VolatilityImmutable, }, &BinOp{ LeftType: types.Float, RightType: types.Float, ReturnType: types.Float, - Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) { - return NewDFloat(*left.(*DFloat) * *right.(*DFloat)), nil - }, Volatility: VolatilityImmutable, }, &BinOp{ LeftType: types.Decimal, RightType: types.Decimal, ReturnType: types.Decimal, - Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) { - l := &left.(*DDecimal).Decimal - r := &right.(*DDecimal).Decimal - dd := &DDecimal{} - _, err := ExactCtx.Mul(&dd.Decimal, l, r) - return dd, err - }, Volatility: VolatilityImmutable, }, // The following two overloads are needed because DInt/DInt = DDecimal. Due @@ -1214,94 +640,48 @@ var BinOps = map[BinaryOperator]binOpOverload{ LeftType: types.Decimal, RightType: types.Int, ReturnType: types.Decimal, - Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) { - l := &left.(*DDecimal).Decimal - r := MustBeDInt(right) - dd := &DDecimal{} - dd.SetInt64(int64(r)) - _, err := ExactCtx.Mul(&dd.Decimal, l, &dd.Decimal) - return dd, err - }, Volatility: VolatilityImmutable, }, &BinOp{ LeftType: types.Int, RightType: types.Decimal, ReturnType: types.Decimal, - Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) { - l := MustBeDInt(left) - r := &right.(*DDecimal).Decimal - dd := &DDecimal{} - dd.SetInt64(int64(l)) - _, err := ExactCtx.Mul(&dd.Decimal, &dd.Decimal, r) - return dd, err - }, Volatility: VolatilityImmutable, }, &BinOp{ LeftType: types.Int, RightType: types.Interval, ReturnType: types.Interval, - Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) { - return &DInterval{Duration: right.(*DInterval).Duration.Mul(int64(MustBeDInt(left)))}, nil - }, Volatility: VolatilityImmutable, }, &BinOp{ LeftType: types.Interval, RightType: types.Int, ReturnType: types.Interval, - Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) { - return &DInterval{Duration: left.(*DInterval).Duration.Mul(int64(MustBeDInt(right)))}, nil - }, Volatility: VolatilityImmutable, }, &BinOp{ LeftType: types.Interval, RightType: types.Float, ReturnType: types.Interval, - Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) { - r := float64(*right.(*DFloat)) - return &DInterval{Duration: left.(*DInterval).Duration.MulFloat(r)}, nil - }, Volatility: VolatilityImmutable, }, &BinOp{ LeftType: types.Float, RightType: types.Interval, ReturnType: types.Interval, - Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) { - l := float64(*left.(*DFloat)) - return &DInterval{Duration: right.(*DInterval).Duration.MulFloat(l)}, nil - }, Volatility: VolatilityImmutable, }, &BinOp{ LeftType: types.Decimal, RightType: types.Interval, ReturnType: types.Interval, - Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) { - l := &left.(*DDecimal).Decimal - t, err := l.Float64() - if err != nil { - return nil, err - } - return &DInterval{Duration: right.(*DInterval).Duration.MulFloat(t)}, nil - }, Volatility: VolatilityImmutable, }, &BinOp{ LeftType: types.Interval, RightType: types.Decimal, ReturnType: types.Interval, - Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) { - r := &right.(*DDecimal).Decimal - t, err := r.Float64() - if err != nil { - return nil, err - } - return &DInterval{Duration: left.(*DInterval).Duration.MulFloat(t)}, nil - }, Volatility: VolatilityImmutable, }, }, @@ -1311,106 +691,42 @@ var BinOps = map[BinaryOperator]binOpOverload{ LeftType: types.Int, RightType: types.Int, ReturnType: types.Decimal, - Fn: func(ctx *EvalContext, left Datum, right Datum) (Datum, error) { - rInt := MustBeDInt(right) - if rInt == 0 { - return nil, ErrDivByZero - } - div := ctx.getTmpDec().SetInt64(int64(rInt)) - dd := &DDecimal{} - dd.SetInt64(int64(MustBeDInt(left))) - _, err := DecimalCtx.Quo(&dd.Decimal, &dd.Decimal, div) - return dd, err - }, Volatility: VolatilityImmutable, }, &BinOp{ LeftType: types.Float, RightType: types.Float, ReturnType: types.Float, - Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) { - r := *right.(*DFloat) - if r == 0.0 { - return nil, ErrDivByZero - } - return NewDFloat(*left.(*DFloat) / r), nil - }, Volatility: VolatilityImmutable, }, &BinOp{ LeftType: types.Decimal, RightType: types.Decimal, ReturnType: types.Decimal, - Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) { - l := &left.(*DDecimal).Decimal - r := &right.(*DDecimal).Decimal - if r.IsZero() { - return nil, ErrDivByZero - } - dd := &DDecimal{} - _, err := DecimalCtx.Quo(&dd.Decimal, l, r) - return dd, err - }, Volatility: VolatilityImmutable, }, &BinOp{ LeftType: types.Decimal, RightType: types.Int, ReturnType: types.Decimal, - Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) { - l := &left.(*DDecimal).Decimal - r := MustBeDInt(right) - if r == 0 { - return nil, ErrDivByZero - } - dd := &DDecimal{} - dd.SetInt64(int64(r)) - _, err := DecimalCtx.Quo(&dd.Decimal, l, &dd.Decimal) - return dd, err - }, Volatility: VolatilityImmutable, }, &BinOp{ LeftType: types.Int, RightType: types.Decimal, ReturnType: types.Decimal, - Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) { - l := MustBeDInt(left) - r := &right.(*DDecimal).Decimal - if r.IsZero() { - return nil, ErrDivByZero - } - dd := &DDecimal{} - dd.SetInt64(int64(l)) - _, err := DecimalCtx.Quo(&dd.Decimal, &dd.Decimal, r) - return dd, err - }, Volatility: VolatilityImmutable, }, &BinOp{ LeftType: types.Interval, RightType: types.Int, ReturnType: types.Interval, - Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) { - rInt := MustBeDInt(right) - if rInt == 0 { - return nil, ErrDivByZero - } - return &DInterval{Duration: left.(*DInterval).Duration.Div(int64(rInt))}, nil - }, Volatility: VolatilityImmutable, }, &BinOp{ LeftType: types.Interval, RightType: types.Float, ReturnType: types.Interval, - Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) { - r := float64(*right.(*DFloat)) - if r == 0.0 { - return nil, ErrDivByZero - } - return &DInterval{Duration: left.(*DInterval).Duration.DivFloat(r)}, nil - }, Volatility: VolatilityImmutable, }, }, @@ -1420,77 +736,30 @@ var BinOps = map[BinaryOperator]binOpOverload{ LeftType: types.Int, RightType: types.Int, ReturnType: types.Int, - Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) { - rInt := MustBeDInt(right) - if rInt == 0 { - return nil, ErrDivByZero - } - return NewDInt(MustBeDInt(left) / rInt), nil - }, Volatility: VolatilityImmutable, }, &BinOp{ LeftType: types.Float, RightType: types.Float, ReturnType: types.Float, - Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) { - l := float64(*left.(*DFloat)) - r := float64(*right.(*DFloat)) - if r == 0.0 { - return nil, ErrDivByZero - } - return NewDFloat(DFloat(math.Trunc(l / r))), nil - }, Volatility: VolatilityImmutable, }, &BinOp{ LeftType: types.Decimal, RightType: types.Decimal, ReturnType: types.Decimal, - Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) { - l := &left.(*DDecimal).Decimal - r := &right.(*DDecimal).Decimal - if r.IsZero() { - return nil, ErrDivByZero - } - dd := &DDecimal{} - _, err := HighPrecisionCtx.QuoInteger(&dd.Decimal, l, r) - return dd, err - }, Volatility: VolatilityImmutable, }, &BinOp{ LeftType: types.Decimal, RightType: types.Int, ReturnType: types.Decimal, - Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) { - l := &left.(*DDecimal).Decimal - r := MustBeDInt(right) - if r == 0 { - return nil, ErrDivByZero - } - dd := &DDecimal{} - dd.SetInt64(int64(r)) - _, err := HighPrecisionCtx.QuoInteger(&dd.Decimal, l, &dd.Decimal) - return dd, err - }, Volatility: VolatilityImmutable, }, &BinOp{ LeftType: types.Int, RightType: types.Decimal, ReturnType: types.Decimal, - Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) { - l := MustBeDInt(left) - r := &right.(*DDecimal).Decimal - if r.IsZero() { - return nil, ErrDivByZero - } - dd := &DDecimal{} - dd.SetInt64(int64(l)) - _, err := HighPrecisionCtx.QuoInteger(&dd.Decimal, &dd.Decimal, r) - return dd, err - }, Volatility: VolatilityImmutable, }, }, @@ -1500,77 +769,30 @@ var BinOps = map[BinaryOperator]binOpOverload{ LeftType: types.Int, RightType: types.Int, ReturnType: types.Int, - Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) { - r := MustBeDInt(right) - if r == 0 { - return nil, ErrDivByZero - } - return NewDInt(MustBeDInt(left) % r), nil - }, Volatility: VolatilityImmutable, }, &BinOp{ LeftType: types.Float, RightType: types.Float, ReturnType: types.Float, - Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) { - l := float64(*left.(*DFloat)) - r := float64(*right.(*DFloat)) - if r == 0.0 { - return nil, ErrDivByZero - } - return NewDFloat(DFloat(math.Mod(l, r))), nil - }, Volatility: VolatilityImmutable, }, &BinOp{ LeftType: types.Decimal, RightType: types.Decimal, ReturnType: types.Decimal, - Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) { - l := &left.(*DDecimal).Decimal - r := &right.(*DDecimal).Decimal - if r.IsZero() { - return nil, ErrDivByZero - } - dd := &DDecimal{} - _, err := HighPrecisionCtx.Rem(&dd.Decimal, l, r) - return dd, err - }, Volatility: VolatilityImmutable, }, &BinOp{ LeftType: types.Decimal, RightType: types.Int, ReturnType: types.Decimal, - Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) { - l := &left.(*DDecimal).Decimal - r := MustBeDInt(right) - if r == 0 { - return nil, ErrDivByZero - } - dd := &DDecimal{} - dd.SetInt64(int64(r)) - _, err := HighPrecisionCtx.Rem(&dd.Decimal, l, &dd.Decimal) - return dd, err - }, Volatility: VolatilityImmutable, }, &BinOp{ LeftType: types.Int, RightType: types.Decimal, ReturnType: types.Decimal, - Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) { - l := MustBeDInt(left) - r := &right.(*DDecimal).Decimal - if r.IsZero() { - return nil, ErrDivByZero - } - dd := &DDecimal{} - dd.SetInt64(int64(l)) - _, err := HighPrecisionCtx.Rem(&dd.Decimal, &dd.Decimal, r) - return dd, err - }, Volatility: VolatilityImmutable, }, }, @@ -1580,44 +802,24 @@ var BinOps = map[BinaryOperator]binOpOverload{ LeftType: types.String, RightType: types.String, ReturnType: types.String, - Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) { - return NewDString(string(MustBeDString(left) + MustBeDString(right))), nil - }, Volatility: VolatilityImmutable, }, &BinOp{ LeftType: types.Bytes, RightType: types.Bytes, ReturnType: types.Bytes, - Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) { - return NewDBytes(*left.(*DBytes) + *right.(*DBytes)), nil - }, Volatility: VolatilityImmutable, }, &BinOp{ LeftType: types.VarBit, RightType: types.VarBit, ReturnType: types.VarBit, - Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) { - lhs := MustBeDBitArray(left) - rhs := MustBeDBitArray(right) - return &DBitArray{ - BitArray: utils.Concat(lhs.BitArray, rhs.BitArray), - }, nil - }, Volatility: VolatilityImmutable, }, &BinOp{ LeftType: types.Jsonb, RightType: types.Jsonb, ReturnType: types.Jsonb, - Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) { - j, err := MustBeDJSON(left).JSON.Concat(MustBeDJSON(right).JSON) - if err != nil { - return nil, err - } - return &DJSON{j}, nil - }, Volatility: VolatilityImmutable, }, }, @@ -1628,37 +830,18 @@ var BinOps = map[BinaryOperator]binOpOverload{ LeftType: types.Int, RightType: types.Int, ReturnType: types.Int, - Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) { - rval := MustBeDInt(right) - if rval < 0 || rval >= 64 { - return nil, ErrShiftArgOutOfRange - } - return NewDInt(MustBeDInt(left) << uint(rval)), nil - }, Volatility: VolatilityImmutable, }, &BinOp{ LeftType: types.VarBit, RightType: types.Int, ReturnType: types.VarBit, - Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) { - lhs := MustBeDBitArray(left) - rhs := MustBeDInt(right) - return &DBitArray{ - BitArray: lhs.BitArray.LeftShiftAny(int64(rhs)), - }, nil - }, Volatility: VolatilityImmutable, }, &BinOp{ LeftType: types.INet, RightType: types.INet, ReturnType: types.Bool, - Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) { - ipAddr := MustBeDIPAddr(left).IPAddr - other := MustBeDIPAddr(right).IPAddr - return MakeDBool(DBool(ipAddr.ContainedBy(&other))), nil - }, Volatility: VolatilityImmutable, }, }, @@ -1668,37 +851,18 @@ var BinOps = map[BinaryOperator]binOpOverload{ LeftType: types.Int, RightType: types.Int, ReturnType: types.Int, - Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) { - rval := MustBeDInt(right) - if rval < 0 || rval >= 64 { - return nil, ErrShiftArgOutOfRange - } - return NewDInt(MustBeDInt(left) >> uint(rval)), nil - }, Volatility: VolatilityImmutable, }, &BinOp{ LeftType: types.VarBit, RightType: types.Int, ReturnType: types.VarBit, - Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) { - lhs := MustBeDBitArray(left) - rhs := MustBeDInt(right) - return &DBitArray{ - BitArray: lhs.BitArray.LeftShiftAny(-int64(rhs)), - }, nil - }, Volatility: VolatilityImmutable, }, &BinOp{ LeftType: types.INet, RightType: types.INet, ReturnType: types.Bool, - Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) { - ipAddr := MustBeDIPAddr(left).IPAddr - other := MustBeDIPAddr(right).IPAddr - return MakeDBool(DBool(ipAddr.Contains(&other))), nil - }, Volatility: VolatilityImmutable, }, }, @@ -1708,60 +872,30 @@ var BinOps = map[BinaryOperator]binOpOverload{ LeftType: types.Int, RightType: types.Int, ReturnType: types.Int, - Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) { - return IntPow(MustBeDInt(left), MustBeDInt(right)) - }, Volatility: VolatilityImmutable, }, &BinOp{ LeftType: types.Float, RightType: types.Float, ReturnType: types.Float, - Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) { - f := math.Pow(float64(*left.(*DFloat)), float64(*right.(*DFloat))) - return NewDFloat(DFloat(f)), nil - }, Volatility: VolatilityImmutable, }, &BinOp{ LeftType: types.Decimal, RightType: types.Decimal, ReturnType: types.Decimal, - Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) { - l := &left.(*DDecimal).Decimal - r := &right.(*DDecimal).Decimal - dd := &DDecimal{} - _, err := DecimalCtx.Pow(&dd.Decimal, l, r) - return dd, err - }, Volatility: VolatilityImmutable, }, &BinOp{ LeftType: types.Decimal, RightType: types.Int, ReturnType: types.Decimal, - Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) { - l := &left.(*DDecimal).Decimal - r := MustBeDInt(right) - dd := &DDecimal{} - dd.SetInt64(int64(r)) - _, err := DecimalCtx.Pow(&dd.Decimal, l, &dd.Decimal) - return dd, err - }, Volatility: VolatilityImmutable, }, &BinOp{ LeftType: types.Int, RightType: types.Decimal, ReturnType: types.Decimal, - Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) { - l := MustBeDInt(left) - r := &right.(*DDecimal).Decimal - dd := &DDecimal{} - dd.SetInt64(int64(l)) - _, err := DecimalCtx.Pow(&dd.Decimal, &dd.Decimal, r) - return dd, err - }, Volatility: VolatilityImmutable, }, }, @@ -1771,32 +905,12 @@ var BinOps = map[BinaryOperator]binOpOverload{ LeftType: types.Jsonb, RightType: types.String, ReturnType: types.Jsonb, - Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) { - j, err := left.(*DJSON).JSON.FetchValKey(string(MustBeDString(right))) - if err != nil { - return nil, err - } - if j == nil { - return DNull, nil - } - return &DJSON{j}, nil - }, Volatility: VolatilityImmutable, }, &BinOp{ LeftType: types.Jsonb, RightType: types.Int, ReturnType: types.Jsonb, - Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) { - j, err := left.(*DJSON).JSON.FetchValIdx(int(MustBeDInt(right))) - if err != nil { - return nil, err - } - if j == nil { - return DNull, nil - } - return &DJSON{j}, nil - }, Volatility: VolatilityImmutable, }, }, @@ -1806,9 +920,6 @@ var BinOps = map[BinaryOperator]binOpOverload{ LeftType: types.Jsonb, RightType: types.MakeArray(types.String), ReturnType: types.Jsonb, - Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) { - return getJSONPath(*left.(*DJSON), *MustBeDArray(right)) - }, Volatility: VolatilityImmutable, }, }, @@ -1818,46 +929,12 @@ var BinOps = map[BinaryOperator]binOpOverload{ LeftType: types.Jsonb, RightType: types.String, ReturnType: types.String, - Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) { - res, err := left.(*DJSON).JSON.FetchValKey(string(MustBeDString(right))) - if err != nil { - return nil, err - } - if res == nil { - return DNull, nil - } - text, err := res.AsText() - if err != nil { - return nil, err - } - if text == nil { - return DNull, nil - } - return NewDString(*text), nil - }, Volatility: VolatilityImmutable, }, &BinOp{ LeftType: types.Jsonb, RightType: types.Int, ReturnType: types.String, - Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) { - res, err := left.(*DJSON).JSON.FetchValIdx(int(MustBeDInt(right))) - if err != nil { - return nil, err - } - if res == nil { - return DNull, nil - } - text, err := res.AsText() - if err != nil { - return nil, err - } - if text == nil { - return DNull, nil - } - return NewDString(*text), nil - }, Volatility: VolatilityImmutable, }, }, @@ -1867,43 +944,11 @@ var BinOps = map[BinaryOperator]binOpOverload{ LeftType: types.Jsonb, RightType: types.MakeArray(types.String), ReturnType: types.String, - Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) { - res, err := getJSONPath(*left.(*DJSON), *MustBeDArray(right)) - if err != nil { - return nil, err - } - if res == DNull { - return DNull, nil - } - text, err := res.(*DJSON).JSON.AsText() - if err != nil { - return nil, err - } - if text == nil { - return DNull, nil - } - return NewDString(*text), nil - }, Volatility: VolatilityImmutable, }, }, } -// timestampMinusBinOp is the implementation of the subtraction -// between types.TimestampTZ operands. -var timestampMinusBinOp *BinOp - -// TimestampDifference computes the interval difference between two -// TimestampTZ datums. The result is a DInterval. The caller must -// ensure that the arguments are of the proper Datum type. -func TimestampDifference(ctx *EvalContext, start, end Datum) (Datum, error) { - return timestampMinusBinOp.Fn(ctx, start, end) -} - -func init() { - timestampMinusBinOp, _ = BinOps[Minus].lookupImpl(types.TimestampTZ, types.TimestampTZ) -} - // CmpOp is a comparison operator. type CmpOp struct { LeftType *types.T @@ -1913,9 +958,6 @@ type CmpOp struct { // whenever either argument is NULL. NullableArgs bool - // Datum return type is a union between *DBool and dNull. - Fn TwoArgFn - Volatility Volatility types TypeList @@ -1956,26 +998,22 @@ func cmpOpFixups(cmpOps map[ComparisonOperator]cmpOpOverload) map[ComparisonOper cmpOps[EQ] = append(cmpOps[EQ], &CmpOp{ LeftType: types.MakeArray(t), RightType: types.MakeArray(t), - Fn: cmpOpScalarEQFn, Volatility: findVolatility(EQ, t), }) cmpOps[LE] = append(cmpOps[LE], &CmpOp{ LeftType: types.MakeArray(t), RightType: types.MakeArray(t), - Fn: cmpOpScalarLEFn, Volatility: findVolatility(LE, t), }) cmpOps[LT] = append(cmpOps[LT], &CmpOp{ LeftType: types.MakeArray(t), RightType: types.MakeArray(t), - Fn: cmpOpScalarLTFn, Volatility: findVolatility(LT, t), }) cmpOps[IsNotDistinctFrom] = append(cmpOps[IsNotDistinctFrom], &CmpOp{ LeftType: types.MakeArray(t), RightType: types.MakeArray(t), - Fn: cmpOpScalarIsFn, NullableArgs: true, Volatility: findVolatility(IsNotDistinctFrom, t), }) @@ -2006,7 +1044,6 @@ func (o cmpOpOverload) LookupImpl(left, right *types.T) (*CmpOp, bool) { } func makeCmpOpOverload( - fn func(ctx *EvalContext, left, right Datum) (Datum, error), a, b *types.T, nullableArgs bool, v Volatility, @@ -2014,23 +1051,22 @@ func makeCmpOpOverload( return &CmpOp{ LeftType: a, RightType: b, - Fn: fn, NullableArgs: nullableArgs, Volatility: v, } } func makeEqFn(a, b *types.T, v Volatility) *CmpOp { - return makeCmpOpOverload(cmpOpScalarEQFn, a, b, false /* NullableArgs */, v) + return makeCmpOpOverload(a, b, false /* NullableArgs */, v) } func makeLtFn(a, b *types.T, v Volatility) *CmpOp { - return makeCmpOpOverload(cmpOpScalarLTFn, a, b, false /* NullableArgs */, v) + return makeCmpOpOverload(a, b, false /* NullableArgs */, v) } func makeLeFn(a, b *types.T, v Volatility) *CmpOp { - return makeCmpOpOverload(cmpOpScalarLEFn, a, b, false /* NullableArgs */, v) + return makeCmpOpOverload(a, b, false /* NullableArgs */, v) } func makeIsFn(a, b *types.T, v Volatility) *CmpOp { - return makeCmpOpOverload(cmpOpScalarIsFn, a, b, true /* NullableArgs */, v) + return makeCmpOpOverload(a, b, true /* NullableArgs */, v) } // CmpOps contains the comparison operations indexed by operation type. @@ -2083,11 +1119,8 @@ var CmpOps = cmpOpFixups(map[ComparisonOperator]cmpOpOverload{ // Tuple comparison. &CmpOp{ - LeftType: types.AnyTuple, - RightType: types.AnyTuple, - Fn: func(ctx *EvalContext, left Datum, right Datum) (Datum, error) { - return cmpOpTupleFn(ctx, *left.(*DTuple), *right.(*DTuple), EQ), nil - }, + LeftType: types.AnyTuple, + RightType: types.AnyTuple, Volatility: VolatilityImmutable, }, }, @@ -2139,11 +1172,8 @@ var CmpOps = cmpOpFixups(map[ComparisonOperator]cmpOpOverload{ // Tuple comparison. &CmpOp{ - LeftType: types.AnyTuple, - RightType: types.AnyTuple, - Fn: func(ctx *EvalContext, left Datum, right Datum) (Datum, error) { - return cmpOpTupleFn(ctx, *left.(*DTuple), *right.(*DTuple), LT), nil - }, + LeftType: types.AnyTuple, + RightType: types.AnyTuple, Volatility: VolatilityImmutable, }, }, @@ -2195,11 +1225,8 @@ var CmpOps = cmpOpFixups(map[ComparisonOperator]cmpOpOverload{ // Tuple comparison. &CmpOp{ - LeftType: types.AnyTuple, - RightType: types.AnyTuple, - Fn: func(ctx *EvalContext, left Datum, right Datum) (Datum, error) { - return cmpOpTupleFn(ctx, *left.(*DTuple), *right.(*DTuple), LE), nil - }, + LeftType: types.AnyTuple, + RightType: types.AnyTuple, Volatility: VolatilityImmutable, }, }, @@ -2208,7 +1235,6 @@ var CmpOps = cmpOpFixups(map[ComparisonOperator]cmpOpOverload{ &CmpOp{ LeftType: types.Unknown, RightType: types.Unknown, - Fn: cmpOpScalarIsFn, NullableArgs: true, // Avoids ambiguous comparison error for NULL IS NOT DISTINCT FROM NULL> isPreferred: true, @@ -2264,13 +1290,7 @@ var CmpOps = cmpOpFixups(map[ComparisonOperator]cmpOpOverload{ LeftType: types.AnyTuple, RightType: types.AnyTuple, NullableArgs: true, - Fn: func(ctx *EvalContext, left Datum, right Datum) (Datum, error) { - if left == DNull || right == DNull { - return MakeDBool(left == DNull && right == DNull), nil - } - return cmpOpTupleFn(ctx, *left.(*DTuple), *right.(*DTuple), IsNotDistinctFrom), nil - }, - Volatility: VolatilityImmutable, + Volatility: VolatilityImmutable, }, }, @@ -2302,34 +1322,24 @@ var CmpOps = cmpOpFixups(map[ComparisonOperator]cmpOpOverload{ Like: { &CmpOp{ - LeftType: types.String, - RightType: types.String, - Fn: func(ctx *EvalContext, left Datum, right Datum) (Datum, error) { - return matchLike(ctx, left, right, false) - }, + LeftType: types.String, + RightType: types.String, Volatility: VolatilityLeakProof, }, }, ILike: { &CmpOp{ - LeftType: types.String, - RightType: types.String, - Fn: func(ctx *EvalContext, left Datum, right Datum) (Datum, error) { - return matchLike(ctx, left, right, true) - }, + LeftType: types.String, + RightType: types.String, Volatility: VolatilityLeakProof, }, }, SimilarTo: { &CmpOp{ - LeftType: types.String, - RightType: types.String, - Fn: func(ctx *EvalContext, left Datum, right Datum) (Datum, error) { - key := similarToKey{s: string(MustBeDString(right)), escape: '\\'} - return matchRegexpWithKey(ctx, left, key) - }, + LeftType: types.String, + RightType: types.String, Volatility: VolatilityLeakProof, }, }, @@ -2337,12 +1347,8 @@ var CmpOps = cmpOpFixups(map[ComparisonOperator]cmpOpOverload{ RegMatch: append( cmpOpOverload{ &CmpOp{ - LeftType: types.String, - RightType: types.String, - Fn: func(ctx *EvalContext, left Datum, right Datum) (Datum, error) { - key := regexpKey{s: string(MustBeDString(right)), caseInsensitive: false} - return matchRegexpWithKey(ctx, left, key) - }, + LeftType: types.String, + RightType: types.String, Volatility: VolatilityImmutable, }, }, @@ -2355,165 +1361,71 @@ var CmpOps = cmpOpFixups(map[ComparisonOperator]cmpOpOverload{ RegIMatch: { &CmpOp{ - LeftType: types.String, - RightType: types.String, - Fn: func(ctx *EvalContext, left Datum, right Datum) (Datum, error) { - key := regexpKey{s: string(MustBeDString(right)), caseInsensitive: true} - return matchRegexpWithKey(ctx, left, key) - }, + LeftType: types.String, + RightType: types.String, Volatility: VolatilityImmutable, }, }, JSONExists: { &CmpOp{ - LeftType: types.Jsonb, - RightType: types.String, - Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) { - e, err := left.(*DJSON).JSON.Exists(string(MustBeDString(right))) - if err != nil { - return nil, err - } - if e { - return DBoolTrue, nil - } - return DBoolFalse, nil - }, + LeftType: types.Jsonb, + RightType: types.String, Volatility: VolatilityImmutable, }, }, JSONSomeExists: { &CmpOp{ - LeftType: types.Jsonb, - RightType: types.StringArray, - Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) { - // TODO(justin): this can be optimized. - for _, k := range MustBeDArray(right).Array { - if k == DNull { - continue - } - e, err := left.(*DJSON).JSON.Exists(string(MustBeDString(k))) - if err != nil { - return nil, err - } - if e { - return DBoolTrue, nil - } - } - return DBoolFalse, nil - }, + LeftType: types.Jsonb, + RightType: types.StringArray, Volatility: VolatilityImmutable, }, }, JSONAllExists: { &CmpOp{ - LeftType: types.Jsonb, - RightType: types.StringArray, - Fn: func(_ *EvalContext, left Datum, right Datum) (Datum, error) { - // TODO(justin): this can be optimized. - for _, k := range MustBeDArray(right).Array { - if k == DNull { - continue - } - e, err := left.(*DJSON).JSON.Exists(string(MustBeDString(k))) - if err != nil { - return nil, err - } - if !e { - return DBoolFalse, nil - } - } - return DBoolTrue, nil - }, + LeftType: types.Jsonb, + RightType: types.StringArray, Volatility: VolatilityImmutable, }, }, Contains: { &CmpOp{ - LeftType: types.AnyArray, - RightType: types.AnyArray, - Fn: func(ctx *EvalContext, left Datum, right Datum) (Datum, error) { - haystack := MustBeDArray(left) - needles := MustBeDArray(right) - return ArrayContains(ctx, haystack, needles) - }, + LeftType: types.AnyArray, + RightType: types.AnyArray, Volatility: VolatilityImmutable, }, &CmpOp{ - LeftType: types.Jsonb, - RightType: types.Jsonb, - Fn: func(ctx *EvalContext, left Datum, right Datum) (Datum, error) { - c, err := json.Contains(left.(*DJSON).JSON, right.(*DJSON).JSON) - if err != nil { - return nil, err - } - return MakeDBool(DBool(c)), nil - }, + LeftType: types.Jsonb, + RightType: types.Jsonb, Volatility: VolatilityImmutable, }, }, ContainedBy: { &CmpOp{ - LeftType: types.AnyArray, - RightType: types.AnyArray, - Fn: func(ctx *EvalContext, left Datum, right Datum) (Datum, error) { - needles := MustBeDArray(left) - haystack := MustBeDArray(right) - return ArrayContains(ctx, haystack, needles) - }, + LeftType: types.AnyArray, + RightType: types.AnyArray, Volatility: VolatilityImmutable, }, &CmpOp{ - LeftType: types.Jsonb, - RightType: types.Jsonb, - Fn: func(ctx *EvalContext, left Datum, right Datum) (Datum, error) { - c, err := json.Contains(right.(*DJSON).JSON, left.(*DJSON).JSON) - if err != nil { - return nil, err - } - return MakeDBool(DBool(c)), nil - }, + LeftType: types.Jsonb, + RightType: types.Jsonb, Volatility: VolatilityImmutable, }, }, Overlaps: append( cmpOpOverload{ &CmpOp{ - LeftType: types.AnyArray, - RightType: types.AnyArray, - Fn: func(ctx *EvalContext, left Datum, right Datum) (Datum, error) { - array := MustBeDArray(left) - other := MustBeDArray(right) - if !array.ParamTyp.Equivalent(other.ParamTyp) { - return nil, pgerror.New(pgcode.DatatypeMismatch, "cannot compare arrays with different element types") - } - for _, needle := range array.Array { - // Nulls don't compare to each other in && syntax. - if needle == DNull { - continue - } - for _, hay := range other.Array { - if needle.Compare(ctx, hay) == 0 { - return DBoolTrue, nil - } - } - } - return DBoolFalse, nil - }, + LeftType: types.AnyArray, + RightType: types.AnyArray, Volatility: VolatilityImmutable, }, &CmpOp{ - LeftType: types.INet, - RightType: types.INet, - Fn: func(_ *EvalContext, left, right Datum) (Datum, error) { - ipAddr := MustBeDIPAddr(left).IPAddr - other := MustBeDIPAddr(right).IPAddr - return MakeDBool(DBool(ipAddr.ContainsOrContainedBy(&other))), nil - }, + LeftType: types.INet, + RightType: types.INet, Volatility: VolatilityImmutable, }, }, @@ -2528,51 +1440,23 @@ var CmpOps = cmpOpFixups(map[ComparisonOperator]cmpOpOverload{ func makeBox2DComparisonOperators(op func(lhs, rhs *geo.CartesianBoundingBox) bool) cmpOpOverload { return cmpOpOverload{ &CmpOp{ - LeftType: types.Box2D, - RightType: types.Box2D, - Fn: func(ctx *EvalContext, left Datum, right Datum) (Datum, error) { - ret := op( - &MustBeDBox2D(left).CartesianBoundingBox, - &MustBeDBox2D(right).CartesianBoundingBox, - ) - return MakeDBool(DBool(ret)), nil - }, + LeftType: types.Box2D, + RightType: types.Box2D, Volatility: VolatilityImmutable, }, &CmpOp{ - LeftType: types.Box2D, - RightType: types.Geometry, - Fn: func(ctx *EvalContext, left Datum, right Datum) (Datum, error) { - ret := op( - &MustBeDBox2D(left).CartesianBoundingBox, - MustBeDGeometry(right).CartesianBoundingBox(), - ) - return MakeDBool(DBool(ret)), nil - }, + LeftType: types.Box2D, + RightType: types.Geometry, Volatility: VolatilityImmutable, }, &CmpOp{ - LeftType: types.Geometry, - RightType: types.Box2D, - Fn: func(ctx *EvalContext, left Datum, right Datum) (Datum, error) { - ret := op( - MustBeDGeometry(left).CartesianBoundingBox(), - &MustBeDBox2D(right).CartesianBoundingBox, - ) - return MakeDBool(DBool(ret)), nil - }, + LeftType: types.Geometry, + RightType: types.Box2D, Volatility: VolatilityImmutable, }, &CmpOp{ - LeftType: types.Geometry, - RightType: types.Geometry, - Fn: func(ctx *EvalContext, left Datum, right Datum) (Datum, error) { - ret := op( - MustBeDGeometry(left).CartesianBoundingBox(), - MustBeDGeometry(right).CartesianBoundingBox(), - ) - return MakeDBool(DBool(ret)), nil - }, + LeftType: types.Geometry, + RightType: types.Geometry, Volatility: VolatilityImmutable, }, } @@ -2607,350 +1491,51 @@ func boolFromCmp(cmp int, op ComparisonOperator) *DBool { } } -func cmpOpScalarFn(ctx *EvalContext, left, right Datum, op ComparisonOperator) Datum { - // Before deferring to the Datum.Compare method, check for values that should - // be handled differently during SQL comparison evaluation than they should when - // ordering Datum values. - if left == DNull || right == DNull { - switch op { - case IsNotDistinctFrom: - return MakeDBool((left == DNull) == (right == DNull)) - - default: - // If either Datum is NULL, the result of the comparison is NULL. - return DNull - } +func makeEvalTupleIn(typ *types.T, v Volatility) *CmpOp { + return &CmpOp{ + LeftType: typ, + RightType: types.AnyTuple, + NullableArgs: true, + Volatility: v, } - cmp := left.Compare(ctx, right) - return boolFromCmp(cmp, op) } -func cmpOpScalarEQFn(ctx *EvalContext, left, right Datum) (Datum, error) { - return cmpOpScalarFn(ctx, left, right, EQ), nil -} -func cmpOpScalarLTFn(ctx *EvalContext, left, right Datum) (Datum, error) { - return cmpOpScalarFn(ctx, left, right, LT), nil -} -func cmpOpScalarLEFn(ctx *EvalContext, left, right Datum) (Datum, error) { - return cmpOpScalarFn(ctx, left, right, LE), nil +// MultipleResultsError is returned by QueryRow when more than one result is +// encountered. +type MultipleResultsError struct { + SQL string // the query that produced this error } -func cmpOpScalarIsFn(ctx *EvalContext, left, right Datum) (Datum, error) { - return cmpOpScalarFn(ctx, left, right, IsNotDistinctFrom), nil + +func (e *MultipleResultsError) Error() string { + return fmt.Sprintf("%s: unexpected multiple results", e.SQL) } -func cmpOpTupleFn(ctx *EvalContext, left, right DTuple, op ComparisonOperator) Datum { - cmp := 0 - sawNull := false - for i, leftElem := range left.D { - rightElem := right.D[i] - // Like with cmpOpScalarFn, check for values that need to be handled - // differently than when ordering Datums. - if leftElem == DNull || rightElem == DNull { - switch op { - case EQ: - // If either Datum is NULL and the op is EQ, we continue the - // comparison and the result is only NULL if the other (non-NULL) - // elements are equal. This is because NULL is thought of as "unknown", - // so a NULL equality comparison does not prevent the equality from - // being proven false, but does prevent it from being proven true. - sawNull = true +// EvalDatabase consists of functions that reference the session database +// and is to be used from EvalContext. +type EvalDatabase interface { + // ParseQualifiedTableName parses a SQL string of the form + // `[ database_name . ] [ schema_name . ] table_name`. + // NB: this is deprecated! Use parser.ParseQualifiedTableName when possible. + ParseQualifiedTableName(sql string) (*TableName, error) - case IsNotDistinctFrom: - // For IS NOT DISTINCT FROM, NULLs are "equal". - if leftElem != DNull || rightElem != DNull { - return DBoolFalse - } + // ResolveTableName expands the given table name and + // makes it point to a valid object. + // If the database name is not given, it uses the search path to find it, and + // sets it on the returned TableName. + // It returns the ID of the resolved table, and an error if the table doesn't exist. + ResolveTableName(ctx context.Context, tn *TableName) (ID, error) - default: - // If either Datum is NULL and the op is not EQ or IS NOT DISTINCT FROM, - // we short-circuit the evaluation and the result of the comparison is - // NULL. This is because NULL is thought of as "unknown" and tuple - // inequality is defined lexicographically, so once a NULL comparison is - // seen, the result of the entire tuple comparison is unknown. - return DNull - } - } else { - cmp = leftElem.Compare(ctx, rightElem) - if cmp != 0 { - break - } - } - } - b := boolFromCmp(cmp, op) - if b == DBoolTrue && sawNull { - // The op is EQ and all non-NULL elements are equal, but we saw at least - // one NULL element. Since NULL comparisons are treated as unknown, the - // result of the comparison becomes unknown (NULL). - return DNull - } - return b + // LookupSchema looks up the schema with the given name in the given + // database. + LookupSchema(ctx context.Context, dbName, scName string) (found bool, scMeta SchemaMeta, err error) } -func makeEvalTupleIn(typ *types.T, v Volatility) *CmpOp { - return &CmpOp{ - LeftType: typ, - RightType: types.AnyTuple, - Fn: func(ctx *EvalContext, arg, values Datum) (Datum, error) { - vtuple := values.(*DTuple) - // If the tuple was sorted during normalization, we can perform an - // efficient binary search to find if the arg is in the tuple (as - // long as the arg doesn't contain any NULLs). - if len(vtuple.D) == 0 { - // If the rhs tuple is empty, the result is always false (even if arg is - // or contains NULL). - return DBoolFalse, nil - } - if arg == DNull { - return DNull, nil - } - argTuple, argIsTuple := arg.(*DTuple) - if vtuple.Sorted() && !(argIsTuple && argTuple.ContainsNull()) { - // The right-hand tuple is already sorted and contains no NULLs, and the - // left side is not NULL (e.g. `NULL IN (1, 2)`) or a tuple that - // contains NULL (e.g. `(1, NULL) IN ((1, 2), (3, 4))`). - // - // We can use binary search to make a determination in this case. This - // is the common case when tuples don't contain NULLs. - _, result := vtuple.SearchSorted(ctx, arg) - return MakeDBool(DBool(result)), nil - } - - sawNull := false - if !argIsTuple { - // The left-hand side is not a tuple, e.g. `1 IN (1, 2)`. - for _, val := range vtuple.D { - if val == DNull { - sawNull = true - } else if val.Compare(ctx, arg) == 0 { - return DBoolTrue, nil - } - } - } else { - // The left-hand side is a tuple, e.g. `(1, 2) IN ((1, 2), (3, 4))`. - for _, val := range vtuple.D { - if val == DNull { - // We allow for a null value to be in the list of tuples, so we - // need to check that upfront. - sawNull = true - } else { - // Use the EQ function which properly handles NULLs. - if res := cmpOpTupleFn(ctx, *argTuple, *val.(*DTuple), EQ); res == DNull { - sawNull = true - } else if res == DBoolTrue { - return DBoolTrue, nil - } - } - } - } - if sawNull { - return DNull, nil - } - return DBoolFalse, nil - }, - NullableArgs: true, - Volatility: v, - } -} - -// evalDatumsCmp evaluates Datums (slice of Datum) using the provided -// sub-operator type (ANY/SOME, ALL) and its CmpOp with the left Datum. -// It returns the result of the ANY/SOME/ALL predicate. -// -// A NULL result is returned if there exists a NULL element and: -// ANY/SOME: no comparisons evaluate to true -// ALL: no comparisons evaluate to false -// -// For example, given 1 < ANY (SELECT * FROM generate_series(1,3)) -// (right is a DTuple), evalTupleCmp would be called with: -// evalDatumsCmp(ctx, LT, Any, CmpOp(LT, leftType, rightParamType), leftDatum, rightTuple.D). -// Similarly, given 1 < ANY (ARRAY[1, 2, 3]) (right is a DArray), -// evalArrayCmp would be called with: -// evalDatumsCmp(ctx, LT, Any, CmpOp(LT, leftType, rightParamType), leftDatum, rightArray.Array). -func evalDatumsCmp( - ctx *EvalContext, op, subOp ComparisonOperator, fn *CmpOp, left Datum, right Datums, -) (Datum, error) { - all := op == All - any := !all - sawNull := false - for _, elem := range right { - if elem == DNull { - sawNull = true - continue - } - - _, newLeft, newRight, _, not := FoldComparisonExpr(subOp, left, elem) - d, err := fn.Fn(ctx, newLeft.(Datum), newRight.(Datum)) - if err != nil { - return nil, err - } - if d == DNull { - sawNull = true - continue - } - - b := d.(*DBool) - res := *b != DBool(not) - if any && res { - return DBoolTrue, nil - } else if all && !res { - return DBoolFalse, nil - } - } - - if sawNull { - // If the right-hand array contains any null elements and no [false,true] - // comparison result is obtained, the result of [ALL,ANY] will be null. - return DNull, nil - } - - if all { - // ALL are true && !sawNull - return DBoolTrue, nil - } - // ANY is false && !sawNull - return DBoolFalse, nil -} - -// MatchLikeEscape matches 'unescaped' with 'pattern' using custom escape character 'escape' which -// must be either empty (which disables the escape mechanism) or a single unicode character. -func MatchLikeEscape( - ctx *EvalContext, unescaped, pattern, escape string, caseInsensitive bool, -) (Datum, error) { - var escapeRune rune - if len(escape) > 0 { - var width int - escapeRune, width = utf8.DecodeRuneInString(escape) - if len(escape) > width { - return DBoolFalse, pgerror.Newf(pgcode.InvalidEscapeSequence, "invalid escape string") - } - } - - if len(unescaped) == 0 { - // An empty string only matches with an empty pattern or a pattern - // consisting only of '%' (if this wildcard is not used as a custom escape - // character). To match PostgreSQL's behavior, we have a special handling - // of this case. - for _, c := range pattern { - if c != '%' || (c == '%' && escape == `%`) { - return DBoolFalse, nil - } - } - return DBoolTrue, nil - } - - like, err := optimizedLikeFunc(pattern, caseInsensitive, escapeRune) - if err != nil { - return DBoolFalse, pgerror.Newf( - pgcode.InvalidRegularExpression, "LIKE regexp compilation failed: %v", err) - } - - if like == nil { - re, err := ConvertLikeToRegexp(ctx, pattern, caseInsensitive, escapeRune) - if err != nil { - return DBoolFalse, err - } - like = func(s string) (bool, error) { - return re.MatchString(s), nil - } - } - matches, err := like(unescaped) - return MakeDBool(DBool(matches)), err -} - -// ConvertLikeToRegexp compiles the specified LIKE pattern as an equivalent -// regular expression. -func ConvertLikeToRegexp( - ctx *EvalContext, pattern string, caseInsensitive bool, escape rune, -) (*regexp.Regexp, error) { - key := likeKey{s: pattern, caseInsensitive: caseInsensitive, escape: escape} - re, err := ctx.ReCache.GetRegexp(key) - if err != nil { - return nil, pgerror.Newf( - pgcode.InvalidRegularExpression, "LIKE regexp compilation failed: %v", err) - } - return re, nil -} - -func matchLike(ctx *EvalContext, left, right Datum, caseInsensitive bool) (Datum, error) { - if left == DNull || right == DNull { - return DNull, nil - } - s, pattern := string(MustBeDString(left)), string(MustBeDString(right)) - if len(s) == 0 { - // An empty string only matches with an empty pattern or a pattern - // consisting only of '%'. To match PostgreSQL's behavior, we have a - // special handling of this case. - for _, c := range pattern { - if c != '%' { - return DBoolFalse, nil - } - } - return DBoolTrue, nil - } - - like, err := optimizedLikeFunc(pattern, caseInsensitive, '\\') - if err != nil { - return DBoolFalse, pgerror.Newf( - pgcode.InvalidRegularExpression, "LIKE regexp compilation failed: %v", err) - } - - if like == nil { - re, err := ConvertLikeToRegexp(ctx, pattern, caseInsensitive, '\\') - if err != nil { - return DBoolFalse, err - } - like = func(s string) (bool, error) { - return re.MatchString(s), nil - } - } - matches, err := like(s) - return MakeDBool(DBool(matches)), err -} - -func matchRegexpWithKey(ctx *EvalContext, str Datum, key RegexpCacheKey) (Datum, error) { - re, err := ctx.ReCache.GetRegexp(key) - if err != nil { - return DBoolFalse, err - } - return MakeDBool(DBool(re.MatchString(string(MustBeDString(str))))), nil -} - -// MultipleResultsError is returned by QueryRow when more than one result is -// encountered. -type MultipleResultsError struct { - SQL string // the query that produced this error -} - -func (e *MultipleResultsError) Error() string { - return fmt.Sprintf("%s: unexpected multiple results", e.SQL) -} - -// EvalDatabase consists of functions that reference the session database -// and is to be used from EvalContext. -type EvalDatabase interface { - // ParseQualifiedTableName parses a SQL string of the form - // `[ database_name . ] [ schema_name . ] table_name`. - // NB: this is deprecated! Use parser.ParseQualifiedTableName when possible. - ParseQualifiedTableName(sql string) (*TableName, error) - - // ResolveTableName expands the given table name and - // makes it point to a valid object. - // If the database name is not given, it uses the search path to find it, and - // sets it on the returned TableName. - // It returns the ID of the resolved table, and an error if the table doesn't exist. - ResolveTableName(ctx context.Context, tn *TableName) (ID, error) - - // LookupSchema looks up the schema with the given name in the given - // database. - LookupSchema(ctx context.Context, dbName, scName string) (found bool, scMeta SchemaMeta, err error) -} - -// EvalPlanner is a limited planner that can be used from EvalContext. -type EvalPlanner interface { - EvalDatabase - TypeReferenceResolver - // ParseType parses a column type. - ParseType(sql string) (*types.T, error) +// EvalPlanner is a limited planner that can be used from EvalContext. +type EvalPlanner interface { + EvalDatabase + TypeReferenceResolver + // ParseType parses a column type. + ParseType(sql string) (*types.T, error) // EvalSubquery returns the Datum for the given subquery node. EvalSubquery(expr *Subquery) (Datum, error) @@ -2986,31 +1571,6 @@ type ClientNoticeSender interface { SendClientNotice(ctx context.Context, notice error) } -// InternalExecutor is a subset of sqlutil.InternalExecutor (which, in turn, is -// implemented by sql.InternalExecutor) used by this sem/tree package which -// can't even import sqlutil. -// -// Note that the functions offered here should be avoided when possible. They -// execute the query as root if an user hadn't been previously set on the -// executor through SetSessionData(). These functions are deprecated in -// sql.InternalExecutor in favor of a safer interface. Unfortunately, those -// safer functions cannot be exposed through this interface because they depend -// on sqlbase, and this package cannot import sqlbase. When possible, downcast -// this to sqlutil.InternalExecutor or sql.InternalExecutor, and use the -// alternatives. -type InternalExecutor interface { - // Query is part of the sqlutil.InternalExecutor interface. - Query( - ctx context.Context, opName string, txn *kv.Txn, - stmt string, qargs ...interface{}, - ) ([]Datums, error) - - // QueryRow is part of the sqlutil.InternalExecutor interface. - QueryRow( - ctx context.Context, opName string, txn *kv.Txn, stmt string, qargs ...interface{}, - ) (Datums, error) -} - // PrivilegedAccessor gives access to certain queries that would otherwise // require someone with RootUser access to query a given data source. // It is defined independently to prevent a circular dependency on sql, tree and sqlbase. @@ -3090,387 +1650,6 @@ type EvalContextTestingKnobs struct { // cost of each expression in the query tree for the purpose of creating // alternate query plans in the optimizer. OptimizerCostPerturbation float64 - - CallbackGenerators map[string]*CallbackValueGenerator -} - -// EvalContext defines the context in which to evaluate an expression, allowing -// the retrieval of state such as the node ID or statement start time. -// -// ATTENTION: Some fields from this struct (particularly, but not exclusively, -// from SessionData) are also represented in execinfrapb.EvalContext. Whenever -// something that affects DistSQL execution is added, it needs to be marshaled -// through that proto too. -// TODO(andrei): remove or limit the duplication. -// -// NOTE(andrei): EvalContext is dusty; it started as a collection of fields -// needed by expression evaluation, but it has grown quite large; some of the -// things in it don't seem to belong in this low-level package (e.g. Planner). -// In the sql package it is embedded by extendedEvalContext, which adds some -// more fields from the sql package. Through that extendedEvalContext, this -// struct now generally used by planNodes. -type EvalContext struct { - // Session variables. This is a read-only copy of the values owned by the - // Session. - SessionData *sessiondata.SessionData - // TxnState is a string representation of the current transactional state. - TxnState string - // TxnReadOnly specifies if the current transaction is read-only. - TxnReadOnly bool - TxnImplicit bool - - // The statement timestamp. May be different for every statement. - // Used for statement_timestamp(). - StmtTimestamp time.Time - // The transaction timestamp. Needs to stay stable for the lifetime - // of a transaction. Used for now(), current_timestamp(), - // transaction_timestamp() and the like. - TxnTimestamp time.Time - - // Placeholders relates placeholder names to their type and, later, value. - // This pointer should always be set to the location of the PlaceholderInfo - // in the corresponding SemaContext during normal execution. Placeholders are - // available during Eval to permit lookup of a particular placeholder's - // underlying datum, if available. - Placeholders *PlaceholderInfo - - // Annotations augments the AST with extra information. This pointer should - // always be set to the location of the Annotations in the corresponding - // SemaContext. - Annotations *Annotations - - // IVarContainer is used to evaluate IndexedVars. - IVarContainer IndexedVarContainer - // iVarContainerStack is used when we swap out IVarContainers in order to - // evaluate an intermediate expression. This keeps track of those which we - // need to restore once we finish evaluating it. - iVarContainerStack []IndexedVarContainer - - // Context holds the context in which the expression is evaluated. - Context context.Context - - // InternalExecutor gives access to an executor to be used for running - // "internal" statements. It may seem bizarre that "expression evaluation" may - // need to run a statement, and yet many builtin functions do it. - // Note that the executor will be "session-bound" - it will inherit session - // variables from a parent session. - InternalExecutor InternalExecutor - - Planner EvalPlanner - - PrivilegedAccessor PrivilegedAccessor - - SessionAccessor EvalSessionAccessor - - ClientNoticeSender ClientNoticeSender - - Sequence SequenceOperators - - Tenant TenantOperator - - // The transaction in which the statement is executing. - Txn *kv.Txn - - ReCache *RegexpCache - tmpDec apd.Decimal - - // TODO(mjibson): remove prepareOnly in favor of a 2-step prepare-exec solution - // that is also able to save the plan to skip work during the exec step. - PrepareOnly bool - - // SkipNormalize indicates whether expressions should be normalized - // (false) or not (true). It is set to true conditionally by - // EXPLAIN(TYPES[, NORMALIZE]). - SkipNormalize bool - - CollationEnv CollationEnvironment - - TestingKnobs EvalContextTestingKnobs -} - -// Copy returns a deep copy of ctx. -func (ctx *EvalContext) Copy() *EvalContext { - ctxCopy := *ctx - ctxCopy.iVarContainerStack = make([]IndexedVarContainer, len(ctx.iVarContainerStack), cap(ctx.iVarContainerStack)) - copy(ctxCopy.iVarContainerStack, ctx.iVarContainerStack) - return &ctxCopy -} - -// PushIVarContainer replaces the current IVarContainer with a different one - -// pushing the current one onto a stack to be replaced later once -// PopIVarContainer is called. -func (ctx *EvalContext) PushIVarContainer(c IndexedVarContainer) { - ctx.iVarContainerStack = append(ctx.iVarContainerStack, ctx.IVarContainer) - ctx.IVarContainer = c -} - -// PopIVarContainer discards the current IVarContainer on the EvalContext, -// replacing it with an older one. -func (ctx *EvalContext) PopIVarContainer() { - ctx.IVarContainer = ctx.iVarContainerStack[len(ctx.iVarContainerStack)-1] - ctx.iVarContainerStack = ctx.iVarContainerStack[:len(ctx.iVarContainerStack)-1] -} - -// Stop closes out the EvalContext and must be called once it is no longer in use. -func (ctx *EvalContext) Stop(c context.Context) {} - -// GetStmtTimestamp retrieves the current statement timestamp as per -// the evaluation context. The timestamp is guaranteed to be nonzero. -func (ctx *EvalContext) GetStmtTimestamp() time.Time { - // TODO(knz): a zero timestamp should never be read, even during - // Prepare. This will need to be addressed. - if !ctx.PrepareOnly && ctx.StmtTimestamp.IsZero() { - panic(errors.AssertionFailedf("zero statement timestamp in EvalContext")) - } - return ctx.StmtTimestamp -} - -// HasPlaceholders returns true if this EvalContext's placeholders have been -// assigned. Will be false during Prepare. -func (ctx *EvalContext) HasPlaceholders() bool { - return ctx.Placeholders != nil -} - -// TimestampToDecimal converts the logical timestamp into a decimal -// value with the number of nanoseconds in the integer part and the -// logical counter in the decimal part. -func TimestampToDecimal(ts hlc.Timestamp) apd.Decimal { - // Compute Walltime * 10^10 + Logical. - // We need 10 decimals for the Logical field because its maximum - // value is 4294967295 (2^32-1), a value with 10 decimal digits. - var res apd.Decimal - val := &res.Coeff - val.SetInt64(ts.WallTime) - val.Mul(val, big10E10) - val.Add(val, big.NewInt(int64(ts.Logical))) - - // val must be positive. If it was set to a negative value above, - // transfer the sign to res.Negative. - res.Negative = val.Sign() < 0 - val.Abs(val) - - // Shift 10 decimals to the right, so that the logical - // field appears as fractional part. - res.Exponent = -10 - return res -} - -// DecimalToInexactDTimestamp is the inverse of TimestampToDecimal. It converts -// a decimal constructed from an hlc.Timestamp into an approximate DTimestamp -// containing the walltime of the hlc.Timestamp. -func DecimalToInexactDTimestamp(d *DDecimal) (*DTimestamp, error) { - var coef big.Int - coef.Set(&d.Decimal.Coeff) - // The physical portion of the HLC is stored shifted up by 10^10, so shift - // it down and clear out the logical component. - coef.Div(&coef, big10E10) - if !coef.IsInt64() { - return nil, pgerror.Newf(pgcode.DatetimeFieldOverflow, "timestamp value out of range: %s", d.String()) - } - return TimestampToInexactDTimestamp(hlc.Timestamp{WallTime: coef.Int64()}), nil -} - -// TimestampToDecimalDatum is the same as TimestampToDecimal, but -// returns a datum. -func TimestampToDecimalDatum(ts hlc.Timestamp) *DDecimal { - res := TimestampToDecimal(ts) - return &DDecimal{ - Decimal: res, - } -} - -// TimestampToInexactDTimestamp converts the logical timestamp into an -// inexact DTimestamp by dropping the logical counter and using the wall -// time at the microsecond precision. -func TimestampToInexactDTimestamp(ts hlc.Timestamp) *DTimestamp { - return MustMakeDTimestamp(time.Unix(0, ts.WallTime).UTC(), time.Microsecond) -} - -// GetRelativeParseTime implements ParseTimeContext. -func (ctx *EvalContext) GetRelativeParseTime() time.Time { - ret := ctx.TxnTimestamp - if ret.IsZero() { - ret = time.Now() - } - return ret.In(ctx.GetLocation()) -} - -// GetTxnTimestamp retrieves the current transaction timestamp as per -// the evaluation context. The timestamp is guaranteed to be nonzero. -func (ctx *EvalContext) GetTxnTimestamp(precision time.Duration) *DTimestampTZ { - // TODO(knz): a zero timestamp should never be read, even during - // Prepare. This will need to be addressed. - if !ctx.PrepareOnly && ctx.TxnTimestamp.IsZero() { - panic(errors.AssertionFailedf("zero transaction timestamp in EvalContext")) - } - return MustMakeDTimestampTZ(ctx.GetRelativeParseTime(), precision) -} - -// GetTxnTimestampNoZone retrieves the current transaction timestamp as per -// the evaluation context. The timestamp is guaranteed to be nonzero. -func (ctx *EvalContext) GetTxnTimestampNoZone(precision time.Duration) *DTimestamp { - // TODO(knz): a zero timestamp should never be read, even during - // Prepare. This will need to be addressed. - if !ctx.PrepareOnly && ctx.TxnTimestamp.IsZero() { - panic(errors.AssertionFailedf("zero transaction timestamp in EvalContext")) - } - // Move the time to UTC, but keeping the location's time. - t := ctx.GetRelativeParseTime() - _, offsetSecs := t.Zone() - return MustMakeDTimestamp(t.Add(time.Second*time.Duration(offsetSecs)).In(time.UTC), precision) -} - -// GetTxnTime retrieves the current transaction time as per -// the evaluation context. -func (ctx *EvalContext) GetTxnTime(precision time.Duration) *DTimeTZ { - // TODO(knz): a zero timestamp should never be read, even during - // Prepare. This will need to be addressed. - if !ctx.PrepareOnly && ctx.TxnTimestamp.IsZero() { - panic(errors.AssertionFailedf("zero transaction timestamp in EvalContext")) - } - return NewDTimeTZFromTime(ctx.GetRelativeParseTime().Round(precision)) -} - -// GetTxnTimeNoZone retrieves the current transaction time as per -// the evaluation context. -func (ctx *EvalContext) GetTxnTimeNoZone(precision time.Duration) *DTime { - // TODO(knz): a zero timestamp should never be read, even during - // Prepare. This will need to be addressed. - if !ctx.PrepareOnly && ctx.TxnTimestamp.IsZero() { - panic(errors.AssertionFailedf("zero transaction timestamp in EvalContext")) - } - return MakeDTime(timeofday.FromTime(ctx.GetRelativeParseTime().Round(precision))) -} - -// SetTxnTimestamp sets the corresponding timestamp in the EvalContext. -func (ctx *EvalContext) SetTxnTimestamp(ts time.Time) { - ctx.TxnTimestamp = ts -} - -// SetStmtTimestamp sets the corresponding timestamp in the EvalContext. -func (ctx *EvalContext) SetStmtTimestamp(ts time.Time) { - ctx.StmtTimestamp = ts -} - -// GetLocation returns the session timezone. -func (ctx *EvalContext) GetLocation() *time.Location { - if ctx.SessionData == nil || ctx.SessionData.DataConversion.Location == nil { - return time.UTC - } - return ctx.SessionData.DataConversion.Location -} - -// Ctx returns the session's context. -func (ctx *EvalContext) Ctx() context.Context { - return ctx.Context -} - -func (ctx *EvalContext) getTmpDec() *apd.Decimal { - return &ctx.tmpDec -} - -// Eval implements the TypedExpr interface. -func (expr *AndExpr) Eval(ctx *EvalContext) (Datum, error) { - left, err := expr.Left.(TypedExpr).Eval(ctx) - if err != nil { - return nil, err - } - if left != DNull { - if v, err := GetBool(left); err != nil { - return nil, err - } else if !v { - return left, nil - } - } - right, err := expr.Right.(TypedExpr).Eval(ctx) - if err != nil { - return nil, err - } - if right == DNull { - return DNull, nil - } - if v, err := GetBool(right); err != nil { - return nil, err - } else if !v { - return right, nil - } - return left, nil -} - -// Eval implements the TypedExpr interface. -func (expr *BinaryExpr) Eval(ctx *EvalContext) (Datum, error) { - left, err := expr.Left.(TypedExpr).Eval(ctx) - if err != nil { - return nil, err - } - if left == DNull && !expr.Fn.NullableArgs { - return DNull, nil - } - right, err := expr.Right.(TypedExpr).Eval(ctx) - if err != nil { - return nil, err - } - if right == DNull && !expr.Fn.NullableArgs { - return DNull, nil - } - res, err := expr.Fn.Fn(ctx, left, right) - if err != nil { - return nil, err - } - if ctx.TestingKnobs.AssertBinaryExprReturnTypes { - if err := ensureExpectedType(expr.Fn.ReturnType, res); err != nil { - return nil, errors.NewAssertionErrorWithWrappedErrf(err, - "binary op %q", expr) - } - } - return res, err -} - -// Eval implements the TypedExpr interface. -func (expr *CaseExpr) Eval(ctx *EvalContext) (Datum, error) { - if expr.Expr != nil { - // CASE WHEN THEN ... - // - // For each "when" expression we compare for equality to . - val, err := expr.Expr.(TypedExpr).Eval(ctx) - if err != nil { - return nil, err - } - - for _, when := range expr.Whens { - arg, err := when.Cond.(TypedExpr).Eval(ctx) - if err != nil { - return nil, err - } - d, err := evalComparison(ctx, EQ, val, arg) - if err != nil { - return nil, err - } - if v, err := GetBool(d); err != nil { - return nil, err - } else if v { - return when.Val.(TypedExpr).Eval(ctx) - } - } - } else { - // CASE WHEN THEN ... - for _, when := range expr.Whens { - d, err := when.Cond.(TypedExpr).Eval(ctx) - if err != nil { - return nil, err - } - if v, err := GetBool(d); err != nil { - return nil, err - } else if v { - return when.Val.(TypedExpr).Eval(ctx) - } - } - } - - if expr.Else != nil { - return expr.Else.(TypedExpr).Eval(ctx) - } - return DNull, nil } // pgSignatureRegexp matches a Postgres function type signature, capturing the @@ -3499,1673 +1678,46 @@ var regTypeInfos = map[oid.Oid]regTypeInfo{ oid.T_regnamespace: {"pg_namespace", "nspname", "namespace", pgcode.UndefinedObject}, } -// queryOidWithJoin looks up the name or OID of an input OID or string in the -// pg_catalog table that the input oid.Oid belongs to. If the input Datum -// is a DOid, the relevant table will be queried by OID; if the input is a -// DString, the table will be queried by its name column. -// -// The return value is a fresh DOid of the input oid.Oid with name and OID -// set to the result of the query. If there was not exactly one result to the -// query, an error will be returned. -func queryOidWithJoin( - ctx *EvalContext, typ *types.T, d Datum, joinClause string, additionalWhere string, -) (*DOid, error) { - ret := &DOid{semanticType: typ} - info := regTypeInfos[typ.Oid()] - var queryCol string - switch d.(type) { - case *DOid: - queryCol = "oid" - case *DString: - queryCol = info.nameCol - default: - return nil, errors.AssertionFailedf("invalid argument to OID cast: %s", d) - } - results, err := ctx.InternalExecutor.QueryRow( - ctx.Ctx(), "queryOidWithJoin", - ctx.Txn, - fmt.Sprintf( - "SELECT %s.oid, %s FROM pg_catalog.%s %s WHERE %s = $1 %s", - info.tableName, info.nameCol, info.tableName, joinClause, queryCol, additionalWhere), - d) - if err != nil { - if errors.HasType(err, (*MultipleResultsError)(nil)) { - return nil, pgerror.Newf(pgcode.AmbiguousAlias, - "more than one %s named %s", info.objName, d) - } - return nil, err - } - if results.Len() == 0 { - return nil, pgerror.Newf(info.errType, "%s %s does not exist", info.objName, d) - } - ret.DInt = results[0].(*DOid).DInt - ret.name = AsStringWithFlags(results[1], FmtBareStrings) - return ret, nil -} - -func queryOid(ctx *EvalContext, typ *types.T, d Datum) (*DOid, error) { - return queryOidWithJoin(ctx, typ, d, "", "") -} - -// Eval implements the TypedExpr interface. -func (expr *CastExpr) Eval(ctx *EvalContext) (Datum, error) { - d, err := expr.Expr.(TypedExpr).Eval(ctx) - if err != nil { - return nil, err - } - - // NULL cast to anything is NULL. - if d == DNull { - return d, nil - } - d = UnwrapDatum(ctx, d) - return PerformCast(ctx, d, expr.ResolvedType()) -} - -// Eval implements the TypedExpr interface. -func (expr *IndirectionExpr) Eval(ctx *EvalContext) (Datum, error) { - var subscriptIdx int - for i, t := range expr.Indirection { - if t.Slice || i > 0 { - return nil, errors.AssertionFailedf("unsupported feature should have been rejected during planning") - } - - d, err := t.Begin.(TypedExpr).Eval(ctx) - if err != nil { - return nil, err - } - if d == DNull { - return d, nil - } - subscriptIdx = int(MustBeDInt(d)) - } - - d, err := expr.Expr.(TypedExpr).Eval(ctx) - if err != nil { - return nil, err - } - if d == DNull { - return d, nil - } - - // Index into the DArray, using 1-indexing. - arr := MustBeDArray(d) - - // VECTOR types use 0-indexing. - switch arr.customOid { - case oid.T_oidvector, oid.T_int2vector: - subscriptIdx++ - } - if subscriptIdx < 1 || subscriptIdx > arr.Len() { - return DNull, nil - } - return arr.Array[subscriptIdx-1], nil -} - -// Eval implements the TypedExpr interface. -func (expr *CollateExpr) Eval(ctx *EvalContext) (Datum, error) { - d, err := expr.Expr.(TypedExpr).Eval(ctx) - if err != nil { - return nil, err - } - unwrapped := UnwrapDatum(ctx, d) - if unwrapped == DNull { - return DNull, nil - } - switch d := unwrapped.(type) { - case *DString: - return NewDCollatedString(string(*d), expr.Locale, &ctx.CollationEnv) - case *DCollatedString: - return NewDCollatedString(d.Contents, expr.Locale, &ctx.CollationEnv) - default: - return nil, pgerror.Newf(pgcode.DatatypeMismatch, "incompatible type for COLLATE: %s", d) - } -} - -// Eval implements the TypedExpr interface. -func (expr *ColumnAccessExpr) Eval(ctx *EvalContext) (Datum, error) { - d, err := expr.Expr.(TypedExpr).Eval(ctx) - if err != nil { - return nil, err - } - return d.(*DTuple).D[expr.ColIndex], nil -} - -// Eval implements the TypedExpr interface. -func (expr *CoalesceExpr) Eval(ctx *EvalContext) (Datum, error) { - for _, e := range expr.Exprs { - d, err := e.(TypedExpr).Eval(ctx) - if err != nil { - return nil, err - } - if d != DNull { - return d, nil - } - } - return DNull, nil -} - -// Eval implements the TypedExpr interface. -// Note: if you're modifying this function, please make sure to adjust -// colexec.comparisonExprAdapter implementations accordingly. -func (expr *ComparisonExpr) Eval(ctx *EvalContext) (Datum, error) { - left, err := expr.Left.(TypedExpr).Eval(ctx) - if err != nil { - return nil, err - } - right, err := expr.Right.(TypedExpr).Eval(ctx) - if err != nil { - return nil, err - } - - op := expr.Operator - if op.HasSubOperator() { - return EvalComparisonExprWithSubOperator(ctx, expr, left, right) - } - - _, newLeft, newRight, _, not := FoldComparisonExpr(op, left, right) - if !expr.Fn.NullableArgs && (newLeft == DNull || newRight == DNull) { - return DNull, nil - } - d, err := expr.Fn.Fn(ctx, newLeft.(Datum), newRight.(Datum)) - if d == DNull || err != nil { - return d, err - } - b, ok := d.(*DBool) - if !ok { - return nil, errors.AssertionFailedf("%v is %T and not *DBool", d, d) - } - return MakeDBool(*b != DBool(not)), nil -} - -// EvalComparisonExprWithSubOperator evaluates a comparison expression that has -// sub-operator. -func EvalComparisonExprWithSubOperator( - ctx *EvalContext, expr *ComparisonExpr, left, right Datum, -) (Datum, error) { - var datums Datums - // Right is either a tuple or an array of Datums. - if !expr.Fn.NullableArgs && right == DNull { - return DNull, nil - } else if tuple, ok := AsDTuple(right); ok { - datums = tuple.D - } else if array, ok := AsDArray(right); ok { - datums = array.Array - } else { - return nil, errors.AssertionFailedf("unhandled right expression %s", right) - } - return evalDatumsCmp(ctx, expr.Operator, expr.SubOperator, expr.Fn, left, datums) -} - -// EvalArgsAndGetGenerator evaluates the arguments and instanciates a -// ValueGenerator for use by set projections. -func (expr *FuncExpr) EvalArgsAndGetGenerator(ctx *EvalContext) (ValueGenerator, error) { - if expr.fn == nil || expr.fnProps.Class != GeneratorClass { - return nil, errors.AssertionFailedf("cannot call EvalArgsAndGetGenerator() on non-aggregate function: %q", ErrString(expr)) - } - nullArg, args, err := expr.evalArgs(ctx) - if err != nil || nullArg { - return nil, err - } - return expr.fn.Generator(ctx, args) -} - -// evalArgs evaluates just the function application's arguments. -// The returned bool indicates that the NULL should be propagated. -func (expr *FuncExpr) evalArgs(ctx *EvalContext) (bool, Datums, error) { - args := make(Datums, len(expr.Exprs)) - for i, e := range expr.Exprs { - arg, err := e.(TypedExpr).Eval(ctx) - if err != nil { - return false, nil, err - } - if arg == DNull && !expr.fnProps.NullableArgs { - return true, nil, nil - } - args[i] = arg - } - return false, args, nil -} - -// Eval implements the TypedExpr interface. -func (expr *FuncExpr) Eval(ctx *EvalContext) (Datum, error) { - nullResult, args, err := expr.evalArgs(ctx) - if err != nil { - return nil, err - } - if nullResult { - return DNull, err - } - - res, err := expr.fn.Fn(ctx, args) - if err != nil { - // If we are facing an explicit error, propagate it unchanged. - fName := expr.Func.String() - if fName == `crdb_internal.force_error` { - return nil, err - } - // Otherwise, wrap it with context. - newErr := errors.Wrapf(err, "%s()", errors.Safe(fName)) - return nil, newErr - } - if ctx.TestingKnobs.AssertFuncExprReturnTypes { - if err := ensureExpectedType(expr.fn.FixedReturnType(), res); err != nil { - return nil, errors.NewAssertionErrorWithWrappedErrf(err, "function %q", expr) - } - } - return res, nil -} - -// ensureExpectedType will return an error if a datum does not match the -// provided type. If the expected type is Any or if the datum is a Null -// type, then no error will be returned. -func ensureExpectedType(exp *types.T, d Datum) error { - if !(exp.Family() == types.AnyFamily || d.ResolvedType().Family() == types.UnknownFamily || - d.ResolvedType().Equivalent(exp)) { - return errors.AssertionFailedf( - "expected return type %q, got: %q", errors.Safe(exp), errors.Safe(d.ResolvedType())) - } - return nil -} - -// Eval implements the TypedExpr interface. -func (expr *IfErrExpr) Eval(ctx *EvalContext) (Datum, error) { - cond, evalErr := expr.Cond.(TypedExpr).Eval(ctx) - if evalErr == nil { - if expr.Else == nil { - return DBoolFalse, nil - } - return cond, nil - } - if expr.ErrCode != nil { - errpat, err := expr.ErrCode.(TypedExpr).Eval(ctx) - if err != nil { - return nil, err - } - if errpat == DNull { - return nil, evalErr - } - errpatStr := string(MustBeDString(errpat)) - if code := pgerror.GetPGCode(evalErr); code != pgcode.MakeCode(errpatStr) { - return nil, evalErr - } - } - if expr.Else == nil { - return DBoolTrue, nil - } - return expr.Else.(TypedExpr).Eval(ctx) -} - -// Eval implements the TypedExpr interface. -func (expr *IfExpr) Eval(ctx *EvalContext) (Datum, error) { - cond, err := expr.Cond.(TypedExpr).Eval(ctx) - if err != nil { - return nil, err - } - if cond == DBoolTrue { - return expr.True.(TypedExpr).Eval(ctx) - } - return expr.Else.(TypedExpr).Eval(ctx) -} - -// Eval implements the TypedExpr interface. -func (expr *IsOfTypeExpr) Eval(ctx *EvalContext) (Datum, error) { - d, err := expr.Expr.(TypedExpr).Eval(ctx) - if err != nil { - return nil, err - } - datumTyp := d.ResolvedType() - - for _, t := range expr.ResolvedTypes() { - if datumTyp.Equivalent(t) { - return MakeDBool(DBool(!expr.Not)), nil - } - } - return MakeDBool(DBool(expr.Not)), nil -} - -// Eval implements the TypedExpr interface. -func (expr *NotExpr) Eval(ctx *EvalContext) (Datum, error) { - d, err := expr.Expr.(TypedExpr).Eval(ctx) - if err != nil { - return nil, err - } - if d == DNull { - return DNull, nil - } - v, err := GetBool(d) - if err != nil { - return nil, err - } - return MakeDBool(!v), nil -} - -// Eval implements the TypedExpr interface. -func (expr *IsNullExpr) Eval(ctx *EvalContext) (Datum, error) { - d, err := expr.Expr.(TypedExpr).Eval(ctx) - if err != nil { - return nil, err - } - if d == DNull { - return MakeDBool(true), nil - } - if t, ok := d.(*DTuple); ok { - // A tuple IS NULL if all elements are NULL. - for _, tupleDatum := range t.D { - if tupleDatum != DNull { - return MakeDBool(false), nil - } - } - return MakeDBool(true), nil - } - return MakeDBool(false), nil -} - -// Eval implements the TypedExpr interface. -func (expr *IsNotNullExpr) Eval(ctx *EvalContext) (Datum, error) { - d, err := expr.Expr.(TypedExpr).Eval(ctx) - if err != nil { - return nil, err - } - if d == DNull { - return MakeDBool(false), nil - } - if t, ok := d.(*DTuple); ok { - // A tuple IS NOT NULL if all elements are not NULL. - for _, tupleDatum := range t.D { - if tupleDatum == DNull { - return MakeDBool(false), nil - } - } - return MakeDBool(true), nil - } - return MakeDBool(true), nil -} - -// Eval implements the TypedExpr interface. -func (expr *NullIfExpr) Eval(ctx *EvalContext) (Datum, error) { - expr1, err := expr.Expr1.(TypedExpr).Eval(ctx) - if err != nil { - return nil, err - } - expr2, err := expr.Expr2.(TypedExpr).Eval(ctx) - if err != nil { - return nil, err - } - cond, err := evalComparison(ctx, EQ, expr1, expr2) - if err != nil { - return nil, err - } - if cond == DBoolTrue { - return DNull, nil - } - return expr1, nil -} - -// Eval implements the TypedExpr interface. -func (expr *OrExpr) Eval(ctx *EvalContext) (Datum, error) { - left, err := expr.Left.(TypedExpr).Eval(ctx) - if err != nil { - return nil, err - } - if left != DNull { - if v, err := GetBool(left); err != nil { - return nil, err - } else if v { - return left, nil - } - } - right, err := expr.Right.(TypedExpr).Eval(ctx) - if err != nil { - return nil, err - } - if right == DNull { - return DNull, nil - } - if v, err := GetBool(right); err != nil { - return nil, err - } else if v { - return right, nil - } - if left == DNull { - return DNull, nil - } - return DBoolFalse, nil -} - -// Eval implements the TypedExpr interface. -func (expr *ParenExpr) Eval(ctx *EvalContext) (Datum, error) { - return expr.Expr.(TypedExpr).Eval(ctx) -} - -// Eval implements the TypedExpr interface. -func (expr *RangeCond) Eval(ctx *EvalContext) (Datum, error) { - return nil, errors.AssertionFailedf("unhandled type %T", expr) -} - -// Eval implements the TypedExpr interface. -func (expr *UnaryExpr) Eval(ctx *EvalContext) (Datum, error) { - d, err := expr.Expr.(TypedExpr).Eval(ctx) - if err != nil { - return nil, err - } - if d == DNull { - return DNull, nil - } - res, err := expr.fn.Fn(ctx, d) - if err != nil { - return nil, err - } - if ctx.TestingKnobs.AssertUnaryExprReturnTypes { - if err := ensureExpectedType(expr.fn.ReturnType, res); err != nil { - return nil, errors.NewAssertionErrorWithWrappedErrf(err, "unary op %q", expr) - } - } - return res, err -} - -// Eval implements the TypedExpr interface. -func (expr DefaultVal) Eval(ctx *EvalContext) (Datum, error) { - return nil, errors.AssertionFailedf("unhandled type %T", expr) -} - -// Eval implements the TypedExpr interface. -func (expr UnqualifiedStar) Eval(ctx *EvalContext) (Datum, error) { - return nil, errors.AssertionFailedf("unhandled type %T", expr) -} - -// Eval implements the TypedExpr interface. -func (expr *UnresolvedName) Eval(ctx *EvalContext) (Datum, error) { - return nil, errors.AssertionFailedf("unhandled type %T", expr) -} - -// Eval implements the TypedExpr interface. -func (expr *AllColumnsSelector) Eval(ctx *EvalContext) (Datum, error) { - return nil, errors.AssertionFailedf("unhandled type %T", expr) -} - -// Eval implements the TypedExpr interface. -func (expr *TupleStar) Eval(ctx *EvalContext) (Datum, error) { - return nil, errors.AssertionFailedf("unhandled type %T", expr) -} - -// Eval implements the TypedExpr interface. -func (expr *ColumnItem) Eval(ctx *EvalContext) (Datum, error) { - return nil, errors.AssertionFailedf("unhandled type %T", expr) -} - -// Eval implements the TypedExpr interface. -func (t *Tuple) Eval(ctx *EvalContext) (Datum, error) { - tuple := NewDTupleWithLen(t.typ, len(t.Exprs)) - for i, v := range t.Exprs { - d, err := v.(TypedExpr).Eval(ctx) - if err != nil { - return nil, err - } - tuple.D[i] = d - } - return tuple, nil -} - -// arrayOfType returns a fresh DArray of the input type. -func arrayOfType(typ *types.T) (*DArray, error) { - if typ.Family() != types.ArrayFamily { - return nil, errors.AssertionFailedf("array node type (%v) is not types.TArray", typ) - } - if err := types.CheckArrayElementType(typ.ArrayContents()); err != nil { - return nil, err - } - return NewDArray(typ.ArrayContents()), nil -} - -// Eval implements the TypedExpr interface. -func (t *Array) Eval(ctx *EvalContext) (Datum, error) { - array, err := arrayOfType(t.ResolvedType()) - if err != nil { - return nil, err - } - - for _, v := range t.Exprs { - d, err := v.(TypedExpr).Eval(ctx) - if err != nil { - return nil, err - } - if err := array.Append(d); err != nil { - return nil, err - } - } - return array, nil -} - -// Eval implements the TypedExpr interface. -func (expr *Subquery) Eval(ctx *EvalContext) (Datum, error) { - return ctx.Planner.EvalSubquery(expr) -} - -// Eval implements the TypedExpr interface. -func (t *ArrayFlatten) Eval(ctx *EvalContext) (Datum, error) { - array, err := arrayOfType(t.ResolvedType()) - if err != nil { - return nil, err - } - - d, err := t.Subquery.(TypedExpr).Eval(ctx) - if err != nil { - return nil, err - } - - tuple, ok := d.(*DTuple) - if !ok { - return nil, errors.AssertionFailedf("array subquery result (%v) is not DTuple", d) - } - array.Array = tuple.D - return array, nil -} - -// Eval implements the TypedExpr interface. -func (t *DBitArray) Eval(_ *EvalContext) (Datum, error) { - return t, nil -} - -// Eval implements the TypedExpr interface. -func (t *DBool) Eval(_ *EvalContext) (Datum, error) { - return t, nil -} - -// Eval implements the TypedExpr interface. -func (t *DBytes) Eval(_ *EvalContext) (Datum, error) { - return t, nil -} - -// Eval implements the TypedExpr interface. -func (t *DUuid) Eval(_ *EvalContext) (Datum, error) { - return t, nil -} - -// Eval implements the TypedExpr interface. -func (t *DIPAddr) Eval(_ *EvalContext) (Datum, error) { - return t, nil -} - -// Eval implements the TypedExpr interface. -func (t *DDate) Eval(_ *EvalContext) (Datum, error) { - return t, nil -} - -// Eval implements the TypedExpr interface. -func (t *DTime) Eval(_ *EvalContext) (Datum, error) { - return t, nil -} - -// Eval implements the TypedExpr interface. -func (t *DTimeTZ) Eval(_ *EvalContext) (Datum, error) { - return t, nil -} - -// Eval implements the TypedExpr interface. -func (t *DFloat) Eval(_ *EvalContext) (Datum, error) { - return t, nil -} - -// Eval implements the TypedExpr interface. -func (t *DDecimal) Eval(_ *EvalContext) (Datum, error) { - return t, nil -} - -// Eval implements the TypedExpr interface. -func (t *DInt) Eval(_ *EvalContext) (Datum, error) { - return t, nil -} - -// Eval implements the TypedExpr interface. -func (t *DInterval) Eval(_ *EvalContext) (Datum, error) { - return t, nil -} - -// Eval implements the TypedExpr interface. -func (t *DBox2D) Eval(_ *EvalContext) (Datum, error) { - return t, nil -} - -// Eval implements the TypedExpr interface. -func (t *DGeography) Eval(_ *EvalContext) (Datum, error) { - return t, nil -} - -// Eval implements the TypedExpr interface. -func (t *DGeometry) Eval(_ *EvalContext) (Datum, error) { - return t, nil -} - -// Eval implements the TypedExpr interface. -func (t *DEnum) Eval(_ *EvalContext) (Datum, error) { - return t, nil -} - -// Eval implements the TypedExpr interface. -func (t *DJSON) Eval(_ *EvalContext) (Datum, error) { - return t, nil -} - -// Eval implements the TypedExpr interface. -func (t dNull) Eval(_ *EvalContext) (Datum, error) { - return t, nil -} - -// Eval implements the TypedExpr interface. -func (t *DString) Eval(_ *EvalContext) (Datum, error) { - return t, nil -} - -// Eval implements the TypedExpr interface. -func (t *DCollatedString) Eval(_ *EvalContext) (Datum, error) { - return t, nil -} - -// Eval implements the TypedExpr interface. -func (t *DTimestamp) Eval(_ *EvalContext) (Datum, error) { - return t, nil -} - -// Eval implements the TypedExpr interface. -func (t *DTimestampTZ) Eval(_ *EvalContext) (Datum, error) { - return t, nil -} - -// Eval implements the TypedExpr interface. -func (t *DTuple) Eval(_ *EvalContext) (Datum, error) { - return t, nil -} - -// Eval implements the TypedExpr interface. -func (t *DArray) Eval(_ *EvalContext) (Datum, error) { - return t, nil -} - -// Eval implements the TypedExpr interface. -func (t *DOid) Eval(_ *EvalContext) (Datum, error) { - return t, nil -} - -// Eval implements the TypedExpr interface. -func (t *DOidWrapper) Eval(_ *EvalContext) (Datum, error) { - return t, nil -} - -// Eval implements the TypedExpr interface. -func (t *Placeholder) Eval(ctx *EvalContext) (Datum, error) { - if !ctx.HasPlaceholders() { - // While preparing a query, there will be no available placeholders. A - // placeholder evaluates to itself at this point. - return t, nil - } - e, ok := ctx.Placeholders.Value(t.Idx) - if !ok { - return nil, pgerror.Newf(pgcode.UndefinedParameter, - "no value provided for placeholder: %s", t) - } - // Placeholder expressions cannot contain other placeholders, so we do - // not need to recurse. - typ := ctx.Placeholders.Types[t.Idx] - if typ == nil { - // All placeholders should be typed at this point. - return nil, errors.AssertionFailedf("missing type for placeholder %s", t) - } - if !e.ResolvedType().Equivalent(typ) { - // This happens when we overrode the placeholder's type during type - // checking, since the placeholder's type hint didn't match the desired - // type for the placeholder. In this case, we cast the expression to - // the desired type. - // TODO(jordan): introduce a restriction on what casts are allowed here. - cast := NewTypedCastExpr(e, typ) - return cast.Eval(ctx) - } - return e.Eval(ctx) -} - -func evalComparison(ctx *EvalContext, op ComparisonOperator, left, right Datum) (Datum, error) { - if left == DNull || right == DNull { - return DNull, nil - } - ltype := left.ResolvedType() - rtype := right.ResolvedType() - if fn, ok := CmpOps[op].LookupImpl(ltype, rtype); ok { - return fn.Fn(ctx, left, right) - } - return nil, pgerror.Newf( - pgcode.UndefinedFunction, "unsupported comparison operator: <%s> %s <%s>", ltype, op, rtype) -} - -// FoldComparisonExpr folds a given comparison operation and its expressions -// into an equivalent operation that will hit in the CmpOps map, returning -// this new operation, along with potentially flipped operands and "flipped" -// and "not" flags. -func FoldComparisonExpr( - op ComparisonOperator, left, right Expr, -) (newOp ComparisonOperator, newLeft Expr, newRight Expr, flipped bool, not bool) { - switch op { - case NE: - // NE(left, right) is implemented as !EQ(left, right). - return EQ, left, right, false, true - case GT: - // GT(left, right) is implemented as LT(right, left) - return LT, right, left, true, false - case GE: - // GE(left, right) is implemented as LE(right, left) - return LE, right, left, true, false - case NotIn: - // NotIn(left, right) is implemented as !IN(left, right) - return In, left, right, false, true - case NotLike: - // NotLike(left, right) is implemented as !Like(left, right) - return Like, left, right, false, true - case NotILike: - // NotILike(left, right) is implemented as !ILike(left, right) - return ILike, left, right, false, true - case NotSimilarTo: - // NotSimilarTo(left, right) is implemented as !SimilarTo(left, right) - return SimilarTo, left, right, false, true - case NotRegMatch: - // NotRegMatch(left, right) is implemented as !RegMatch(left, right) - return RegMatch, left, right, false, true - case NotRegIMatch: - // NotRegIMatch(left, right) is implemented as !RegIMatch(left, right) - return RegIMatch, left, right, false, true - case IsDistinctFrom: - // IsDistinctFrom(left, right) is implemented as !IsNotDistinctFrom(left, right) - // Note: this seems backwards, but IS NOT DISTINCT FROM is an extended - // version of IS and IS DISTINCT FROM is an extended version of IS NOT. - return IsNotDistinctFrom, left, right, false, true +// FoldComparisonExpr folds a given comparison operation and its expressions +// into an equivalent operation that will hit in the CmpOps map, returning +// this new operation, along with potentially flipped operands and "flipped" +// and "not" flags. +func FoldComparisonExpr( + op ComparisonOperator, left, right Expr, +) (newOp ComparisonOperator, newLeft Expr, newRight Expr, flipped bool, not bool) { + switch op { + case NE: + // NE(left, right) is implemented as !EQ(left, right). + return EQ, left, right, false, true + case GT: + // GT(left, right) is implemented as LT(right, left) + return LT, right, left, true, false + case GE: + // GE(left, right) is implemented as LE(right, left) + return LE, right, left, true, false + case NotIn: + // NotIn(left, right) is implemented as !IN(left, right) + return In, left, right, false, true + case NotLike: + // NotLike(left, right) is implemented as !Like(left, right) + return Like, left, right, false, true + case NotILike: + // NotILike(left, right) is implemented as !ILike(left, right) + return ILike, left, right, false, true + case NotSimilarTo: + // NotSimilarTo(left, right) is implemented as !SimilarTo(left, right) + return SimilarTo, left, right, false, true + case NotRegMatch: + // NotRegMatch(left, right) is implemented as !RegMatch(left, right) + return RegMatch, left, right, false, true + case NotRegIMatch: + // NotRegIMatch(left, right) is implemented as !RegIMatch(left, right) + return RegIMatch, left, right, false, true + case IsDistinctFrom: + // IsDistinctFrom(left, right) is implemented as !IsNotDistinctFrom(left, right) + // Note: this seems backwards, but IS NOT DISTINCT FROM is an extended + // version of IS and IS DISTINCT FROM is an extended version of IS NOT. + return IsNotDistinctFrom, left, right, false, true } return op, left, right, false, false } - -// hasUnescapedSuffix returns true if the ending byte is suffix and s has an -// even number of escapeTokens preceding suffix. Otherwise hasUnescapedSuffix -// will return false. -func hasUnescapedSuffix(s string, suffix byte, escapeToken string) bool { - if s[len(s)-1] == suffix { - var count int - idx := len(s) - len(escapeToken) - 1 - for idx >= 0 && s[idx:idx+len(escapeToken)] == escapeToken { - count++ - idx -= len(escapeToken) - } - return count%2 == 0 - } - return false -} - -// Simplifies LIKE/ILIKE expressions that do not need full regular expressions to -// evaluate the condition. For example, when the expression is just checking to see -// if a string starts with a given pattern. -func optimizedLikeFunc( - pattern string, caseInsensitive bool, escape rune, -) (func(string) (bool, error), error) { - switch len(pattern) { - case 0: - return func(s string) (bool, error) { - return s == "", nil - }, nil - case 1: - switch pattern[0] { - case '%': - if escape == '%' { - return nil, pgerror.Newf(pgcode.InvalidEscapeSequence, "LIKE pattern must not end with escape character") - } - return func(s string) (bool, error) { - return true, nil - }, nil - case '_': - if escape == '_' { - return nil, pgerror.Newf(pgcode.InvalidEscapeSequence, "LIKE pattern must not end with escape character") - } - return func(s string) (bool, error) { - if len(s) == 0 { - return false, nil - } - firstChar, _ := utf8.DecodeRuneInString(s) - if firstChar == utf8.RuneError { - return false, errors.Errorf("invalid encoding of the first character in string %s", s) - } - return len(s) == len(string(firstChar)), nil - }, nil - } - default: - if !strings.ContainsAny(pattern[1:len(pattern)-1], "_%") { - // Patterns with even number of escape characters preceding the ending - // `%` will have anyEnd set to true (if `%` itself is not an escape - // character). Otherwise anyEnd will be set to false. - anyEnd := hasUnescapedSuffix(pattern, '%', string(escape)) && escape != '%' - // If '%' is the escape character, then it's not a wildcard. - anyStart := pattern[0] == '%' && escape != '%' - - // Patterns with even number of escape characters preceding the ending - // `_` will have singleAnyEnd set to true (if `_` itself is not an escape - // character). Otherwise singleAnyEnd will be set to false. - singleAnyEnd := hasUnescapedSuffix(pattern, '_', string(escape)) && escape != '_' - // If '_' is the escape character, then it's not a wildcard. - singleAnyStart := pattern[0] == '_' && escape != '_' - - // Since we've already checked for escaped characters - // at the end, we can un-escape every character. - // This is required since we do direct string - // comparison. - var err error - if pattern, err = unescapePattern(pattern, string(escape), true /* emitEscapeCharacterLastError */); err != nil { - return nil, err - } - switch { - case anyEnd && anyStart: - return func(s string) (bool, error) { - substr := pattern[1 : len(pattern)-1] - if caseInsensitive { - s, substr = strings.ToUpper(s), strings.ToUpper(substr) - } - return strings.Contains(s, substr), nil - }, nil - - case anyEnd: - return func(s string) (bool, error) { - prefix := pattern[:len(pattern)-1] - if singleAnyStart { - if len(s) == 0 { - return false, nil - } - prefix = prefix[1:] - firstChar, _ := utf8.DecodeRuneInString(s) - if firstChar == utf8.RuneError { - return false, errors.Errorf("invalid encoding of the first character in string %s", s) - } - s = s[len(string(firstChar)):] - } - if caseInsensitive { - s, prefix = strings.ToUpper(s), strings.ToUpper(prefix) - } - return strings.HasPrefix(s, prefix), nil - }, nil - - case anyStart: - return func(s string) (bool, error) { - suffix := pattern[1:] - if singleAnyEnd { - if len(s) == 0 { - return false, nil - } - - suffix = suffix[:len(suffix)-1] - lastChar, _ := utf8.DecodeLastRuneInString(s) - if lastChar == utf8.RuneError { - return false, errors.Errorf("invalid encoding of the last character in string %s", s) - } - s = s[:len(s)-len(string(lastChar))] - } - if caseInsensitive { - s, suffix = strings.ToUpper(s), strings.ToUpper(suffix) - } - return strings.HasSuffix(s, suffix), nil - }, nil - - case singleAnyStart || singleAnyEnd: - return func(s string) (bool, error) { - if len(s) < 1 { - return false, nil - } - firstChar, _ := utf8.DecodeRuneInString(s) - if firstChar == utf8.RuneError { - return false, errors.Errorf("invalid encoding of the first character in string %s", s) - } - lastChar, _ := utf8.DecodeLastRuneInString(s) - if lastChar == utf8.RuneError { - return false, errors.Errorf("invalid encoding of the last character in string %s", s) - } - if singleAnyStart && singleAnyEnd && len(string(firstChar))+len(string(lastChar)) > len(s) { - return false, nil - } - - if singleAnyStart { - pattern = pattern[1:] - s = s[len(string(firstChar)):] - } - - if singleAnyEnd { - pattern = pattern[:len(pattern)-1] - s = s[:len(s)-len(string(lastChar))] - } - - if caseInsensitive { - s, pattern = strings.ToUpper(s), strings.ToUpper(pattern) - } - - // We don't have to check for - // prefixes/suffixes since we do not - // have '%': - // - singleAnyEnd && anyStart handled - // in case anyStart - // - singleAnyStart && anyEnd handled - // in case anyEnd - return s == pattern, nil - }, nil - } - } - } - return nil, nil -} - -type likeKey struct { - s string - caseInsensitive bool - escape rune -} - -// unescapePattern unescapes a pattern for a given escape token. -// It handles escaped escape tokens properly by maintaining them as the escape -// token in the return string. -// For example, suppose we have escape token `\` (e.g. `B` is escaped in -// `A\BC` and `\` is escaped in `A\\C`). -// We need to convert -// `\` --> `` -// `\\` --> `\` -// We cannot simply use strings.Replace for each conversion since the first -// conversion will incorrectly replace our escaped escape token `\\` with ``. -// Another example is if our escape token is `\\` (e.g. after -// regexp.QuoteMeta). -// We need to convert -// `\\` --> `` -// `\\\\` --> `\\` -func unescapePattern( - pattern, escapeToken string, emitEscapeCharacterLastError bool, -) (string, error) { - escapedEscapeToken := escapeToken + escapeToken - - // We need to subtract the escaped escape tokens to avoid double - // counting. - nEscapes := strings.Count(pattern, escapeToken) - strings.Count(pattern, escapedEscapeToken) - if nEscapes == 0 { - return pattern, nil - } - - // Allocate buffer for final un-escaped pattern. - ret := make([]byte, len(pattern)-nEscapes*len(escapeToken)) - retWidth := 0 - for i := 0; i < nEscapes; i++ { - nextIdx := strings.Index(pattern, escapeToken) - if nextIdx == len(pattern)-len(escapeToken) && emitEscapeCharacterLastError { - return "", pgerror.Newf(pgcode.InvalidEscapeSequence, `LIKE pattern must not end with escape character`) - } - - retWidth += copy(ret[retWidth:], pattern[:nextIdx]) - - if nextIdx < len(pattern)-len(escapedEscapeToken) && pattern[nextIdx:nextIdx+len(escapedEscapeToken)] == escapedEscapeToken { - // We have an escaped escape token. - // We want to keep it as the original escape token in - // the return string. - retWidth += copy(ret[retWidth:], escapeToken) - pattern = pattern[nextIdx+len(escapedEscapeToken):] - continue - } - - // Skip over the escape character we removed. - pattern = pattern[nextIdx+len(escapeToken):] - } - - retWidth += copy(ret[retWidth:], pattern) - return string(ret[0:retWidth]), nil -} - -// replaceUnescaped replaces all instances of oldStr that are not escaped (read: -// preceded) with the specified unescape token with newStr. -// For example, with an escape token of `\\` -// replaceUnescaped("TE\\__ST", "_", ".", `\\`) --> "TE\\_.ST" -// replaceUnescaped("TE\\%%ST", "%", ".*", `\\`) --> "TE\\%.*ST" -// If the preceding escape token is escaped, then oldStr will be replaced. -// For example -// replaceUnescaped("TE\\\\_ST", "_", ".", `\\`) --> "TE\\\\.ST" -func replaceUnescaped(s, oldStr, newStr string, escapeToken string) string { - // We count the number of occurrences of 'oldStr'. - // This however can be an overestimate since the oldStr token could be - // escaped. e.g. `\\_`. - nOld := strings.Count(s, oldStr) - if nOld == 0 { - return s - } - - // Allocate buffer for final string. - // This can be an overestimate since some of the oldStr tokens may - // be escaped. - // This is fine since we keep track of the running number of bytes - // actually copied. - // It's rather difficult to count the exact number of unescaped - // tokens without manually iterating through the entire string and - // keeping track of escaped escape tokens. - retLen := len(s) - // If len(newStr) - len(oldStr) < 0, then this can under-allocate which - // will not behave correctly with copy. - if addnBytes := nOld * (len(newStr) - len(oldStr)); addnBytes > 0 { - retLen += addnBytes - } - ret := make([]byte, retLen) - retWidth := 0 - start := 0 -OldLoop: - for i := 0; i < nOld; i++ { - nextIdx := start + strings.Index(s[start:], oldStr) - - escaped := false - for { - // We need to look behind to check if the escape token - // is really an escape token. - // E.g. if our specified escape token is `\\` and oldStr - // is `_`, then - // `\\_` --> escaped - // `\\\\_` --> not escaped - // `\\\\\\_` --> escaped - curIdx := nextIdx - lookbehindIdx := curIdx - len(escapeToken) - for lookbehindIdx >= 0 && s[lookbehindIdx:curIdx] == escapeToken { - escaped = !escaped - curIdx = lookbehindIdx - lookbehindIdx = curIdx - len(escapeToken) - } - - // The token was not be escaped. Proceed. - if !escaped { - break - } - - // Token was escaped. Copy everything over and continue. - retWidth += copy(ret[retWidth:], s[start:nextIdx+len(oldStr)]) - start = nextIdx + len(oldStr) - - // Continue with next oldStr token. - continue OldLoop - } - - // Token was not escaped so we replace it with newStr. - // Two copies is more efficient than concatenating the slices. - retWidth += copy(ret[retWidth:], s[start:nextIdx]) - retWidth += copy(ret[retWidth:], newStr) - start = nextIdx + len(oldStr) - } - - retWidth += copy(ret[retWidth:], s[start:]) - return string(ret[0:retWidth]) -} - -// Replaces all custom escape characters in s with `\\` only when they are unescaped. (1) -// E.g. original pattern after QuoteMeta after replaceCustomEscape with '@' as escape -// '@w@w' -> '@w@w' -> '\\w\\w' -// '@\@\' -> '@\\@\\' -> '\\\\\\\\' -// -// When an escape character is escaped, we replace it with its single occurrence. (2) -// E.g. original pattern after QuoteMeta after replaceCustomEscape with '@' as escape -// '@@w@w' -> '@@w@w' -> '@w\\w' -// '@@@\' -> '@@@\\' -> '@\\\\' -// -// At the same time, we do not want to confuse original backslashes (which -// after QuoteMeta are '\\') with backslashes that replace our custom escape characters, -// so we escape these original backslashes again by converting '\\' into '\\\\'. (3) -// E.g. original pattern after QuoteMeta after replaceCustomEscape with '@' as escape -// '@\' -> '@\\' -> '\\\\\\' -// '@\@@@\' -> '@\\@@@\\' -> '\\\\\\@\\\\\\' -// -// Explanation of the last example: -// 1. we replace '@' with '\\' since it's unescaped; -// 2. we escape single original backslash ('\' is not our escape character, so we want -// the pattern to understand it) by putting an extra backslash in front of it. However, -// we later will call unescapePattern, so we need to double our double backslashes. -// Therefore, '\\' is converted into '\\\\'. -// 3. '@@' is replaced by '@' because it is escaped escape character. -// 4. '@' is replaced with '\\' since it's unescaped. -// 5. Similar logic to step 2: '\\' -> '\\\\'. -// -// We always need to keep in mind that later call of unescapePattern -// to actually unescape '\\' and '\\\\' is necessary and that -// escape must be a single unicode character and not `\`. -func replaceCustomEscape(s string, escape rune) (string, error) { - changed, retLen, err := calculateLengthAfterReplacingCustomEscape(s, escape) - if err != nil { - return "", err - } - if !changed { - return s, nil - } - - sLen := len(s) - ret := make([]byte, retLen) - retIndex, sIndex := 0, 0 - for retIndex < retLen { - sRune, w := utf8.DecodeRuneInString(s[sIndex:]) - if sRune == escape { - // We encountered an escape character. - if sIndex+w < sLen { - // Escape character is not the last character in s, so we need - // to look ahead to figure out how to process it. - tRune, _ := utf8.DecodeRuneInString(s[(sIndex + w):]) - if tRune == escape { - // Escape character is escaped, so we replace its two occurrences with just one. See (2). - // We copied only one escape character to ret, so we advance retIndex only by w. - // Since we've already processed two characters in s, we advance sIndex by 2*w. - utf8.EncodeRune(ret[retIndex:], escape) - retIndex += w - sIndex += 2 * w - } else { - // Escape character is unescaped, so we replace it with `\\`. See (1). - // Since we've added two bytes to ret, we advance retIndex by 2. - // We processed only a single escape character in s, we advance sIndex by w. - ret[retIndex] = '\\' - ret[retIndex+1] = '\\' - retIndex += 2 - sIndex += w - } - } else { - // Escape character is the last character in s which is an error - // that must have been caught in calculateLengthAfterReplacingCustomEscape. - return "", errors.AssertionFailedf( - "unexpected: escape character is the last one in replaceCustomEscape.") - } - } else if s[sIndex] == '\\' { - // We encountered a backslash, so we need to look ahead to figure out how - // to process it. - if sIndex+1 == sLen { - // This case should never be reached since it should - // have been caught in calculateLengthAfterReplacingCustomEscape. - return "", errors.AssertionFailedf( - "unexpected: a single backslash encountered in replaceCustomEscape.") - } else if s[sIndex+1] == '\\' { - // We want to escape '\\' to `\\\\` for correct processing later by unescapePattern. See (3). - // Since we've added four characters to ret, we advance retIndex by 4. - // Since we've already processed two characters in s, we advance sIndex by 2. - ret[retIndex] = '\\' - ret[retIndex+1] = '\\' - ret[retIndex+2] = '\\' - ret[retIndex+3] = '\\' - retIndex += 4 - sIndex += 2 - } else { - // A metacharacter other than a backslash is escaped here. - // Note: all metacharacters are encoded as a single byte, so it is - // correct to just convert it to string and to compare against a char - // in s. - if string(s[sIndex+1]) == string(escape) { - // The metacharacter is our custom escape character. We need to look - // ahead to process it. - if sIndex+2 == sLen { - // Escape character is the last character in s which is an error - // that must have been caught in calculateLengthAfterReplacingCustomEscape. - return "", errors.AssertionFailedf( - "unexpected: escape character is the last one in replaceCustomEscape.") - } - if sIndex+4 <= sLen { - if s[sIndex+2] == '\\' && string(s[sIndex+3]) == string(escape) { - // We have a sequence of `\`+escape+`\`+escape which is replaced - // by `\`+escape. - ret[retIndex] = '\\' - // Note: all metacharacters are encoded as a single byte, so it - // is safe to just convert it to string and take the first - // character. - ret[retIndex+1] = string(escape)[0] - retIndex += 2 - sIndex += 4 - continue - } - } - // The metacharacter is escaping something different than itself, so - // `\`+escape will be replaced by `\`. - ret[retIndex] = '\\' - retIndex++ - sIndex += 2 - } else { - // The metacharacter is not our custom escape character, so we're - // simply copying the backslash and the metacharacter. - ret[retIndex] = '\\' - ret[retIndex+1] = s[sIndex+1] - retIndex += 2 - sIndex += 2 - } - } - } else { - // Regular symbol, so we simply copy it. - copy(ret[retIndex:], s[sIndex:sIndex+w]) - retIndex += w - sIndex += w - } - } - return string(ret), nil -} - -// calculateLengthAfterReplacingCustomEscape returns whether the pattern changes, the length -// of the resulting pattern after calling replaceCustomEscape, and any error if found. -func calculateLengthAfterReplacingCustomEscape(s string, escape rune) (bool, int, error) { - changed := false - retLen, sLen := 0, len(s) - for i := 0; i < sLen; { - sRune, w := utf8.DecodeRuneInString(s[i:]) - if sRune == escape { - // We encountered an escape character. - if i+w < sLen { - // Escape character is not the last character in s, so we need - // to look ahead to figure out how to process it. - tRune, _ := utf8.DecodeRuneInString(s[(i + w):]) - if tRune == escape { - // Escape character is escaped, so we'll replace its two occurrences with just one. - // See (2) in the comment above replaceCustomEscape. - changed = true - retLen += w - i += 2 * w - } else { - // Escape character is unescaped, so we'll replace it with `\\`. - // See (1) in the comment above replaceCustomEscape. - changed = true - retLen += 2 - i += w - } - } else { - // Escape character is the last character in s, so we need to return an error. - return false, 0, pgerror.Newf(pgcode.InvalidEscapeSequence, "LIKE pattern must not end with escape character") - } - } else if s[i] == '\\' { - // We encountered a backslash, so we need to look ahead to figure out how - // to process it. - if i+1 == sLen { - // This case should never be reached because the backslash should be - // escaping one of regexp metacharacters. - return false, 0, pgerror.Newf(pgcode.InvalidEscapeSequence, "Unexpected behavior during processing custom escape character.") - } else if s[i+1] == '\\' { - // We'll want to escape '\\' to `\\\\` for correct processing later by - // unescapePattern. See (3) in the comment above replaceCustomEscape. - changed = true - retLen += 4 - i += 2 - } else { - // A metacharacter other than a backslash is escaped here. - if string(s[i+1]) == string(escape) { - // The metacharacter is our custom escape character. We need to look - // ahead to process it. - if i+2 == sLen { - // Escape character is the last character in s, so we need to return an error. - return false, 0, pgerror.Newf(pgcode.InvalidEscapeSequence, "LIKE pattern must not end with escape character") - } - if i+4 <= sLen { - if s[i+2] == '\\' && string(s[i+3]) == string(escape) { - // We have a sequence of `\`+escape+`\`+escape which will be - // replaced by `\`+escape. - changed = true - retLen += 2 - i += 4 - continue - } - } - // The metacharacter is escaping something different than itself, so - // `\`+escape will be replaced by `\`. - changed = true - retLen++ - i += 2 - } else { - // The metacharacter is not our custom escape character, so we're - // simply copying the backslash and the metacharacter. - retLen += 2 - i += 2 - } - } - } else { - // Regular symbol, so we'll simply copy it. - retLen += w - i += w - } - } - return changed, retLen, nil -} - -// Pattern implements the RegexpCacheKey interface. -// The strategy for handling custom escape character -// is to convert all unescaped escape character into '\'. -// k.escape can either be empty or a single character. -func (k likeKey) Pattern() (string, error) { - // QuoteMeta escapes all regexp metacharacters (`\.+*?()|[]{}^$`) with a `\`. - pattern := regexp.QuoteMeta(k.s) - var err error - if k.escape == 0 { - // Replace all LIKE/ILIKE specific wildcards with standard wildcards - // (escape character is empty - escape mechanism is turned off - so - // all '%' and '_' are actual wildcards regardless of what precedes them. - pattern = strings.Replace(pattern, `%`, `.*`, -1) - pattern = strings.Replace(pattern, `_`, `.`, -1) - } else if k.escape == '\\' { - // Replace LIKE/ILIKE specific wildcards with standard wildcards only when - // these wildcards are not escaped by '\\' (equivalent of '\' after QuoteMeta). - pattern = replaceUnescaped(pattern, `%`, `.*`, `\\`) - pattern = replaceUnescaped(pattern, `_`, `.`, `\\`) - } else { - // k.escape is non-empty and not `\`. - // If `%` is escape character, then it's not a wildcard. - if k.escape != '%' { - // Replace LIKE/ILIKE specific wildcards '%' only if it's unescaped. - if k.escape == '.' { - // '.' is the escape character, so for correct processing later by - // replaceCustomEscape we need to escape it by itself. - pattern = replaceUnescaped(pattern, `%`, `..*`, regexp.QuoteMeta(string(k.escape))) - } else if k.escape == '*' { - // '*' is the escape character, so for correct processing later by - // replaceCustomEscape we need to escape it by itself. - pattern = replaceUnescaped(pattern, `%`, `.**`, regexp.QuoteMeta(string(k.escape))) - } else { - pattern = replaceUnescaped(pattern, `%`, `.*`, regexp.QuoteMeta(string(k.escape))) - } - } - // If `_` is escape character, then it's not a wildcard. - if k.escape != '_' { - // Replace LIKE/ILIKE specific wildcards '_' only if it's unescaped. - if k.escape == '.' { - // '.' is the escape character, so for correct processing later by - // replaceCustomEscape we need to escape it by itself. - pattern = replaceUnescaped(pattern, `_`, `..`, regexp.QuoteMeta(string(k.escape))) - } else { - pattern = replaceUnescaped(pattern, `_`, `.`, regexp.QuoteMeta(string(k.escape))) - } - } - - // If a sequence of symbols ` escape+`\\` ` is unescaped, then that escape - // character escapes backslash in the original pattern (we need to use double - // backslash because of QuoteMeta behavior), so we want to "consume" the escape character. - pattern = replaceUnescaped(pattern, string(k.escape)+`\\`, `\\`, regexp.QuoteMeta(string(k.escape))) - - // We want to replace all escape characters with `\\` only - // when they are unescaped. When an escape character is escaped, - // we replace it with its single occurrence. - if pattern, err = replaceCustomEscape(pattern, k.escape); err != nil { - return pattern, err - } - } - - // After QuoteMeta, all '\' were converted to '\\'. - // After replaceCustomEscape, our custom unescaped escape characters were converted to `\\`, - // so now our pattern contains only '\\' as escape tokens. - // We need to unescape escaped escape tokens `\\` (now `\\\\`) and - // other escaped characters `\A` (now `\\A`). - if k.escape != 0 { - // We do not want to return an error when pattern ends with the supposed escape character `\` - // whereas the actual escape character is not `\`. The case when the pattern ends with - // an actual escape character is handled in replaceCustomEscape. For example, with '-' as - // the escape character on pattern 'abc\\' we do not want to return an error 'pattern ends - // with escape character' because '\\' is not an escape character in this case. - if pattern, err = unescapePattern( - pattern, - `\\`, - k.escape == '\\', /* emitEscapeCharacterLastError */ - ); err != nil { - return "", err - } - } - - return anchorPattern(pattern, k.caseInsensitive), nil -} - -type similarToKey struct { - s string - escape rune -} - -// Pattern implements the RegexpCacheKey interface. -func (k similarToKey) Pattern() (string, error) { - pattern := similarEscapeCustomChar(k.s, k.escape, k.escape != 0) - return anchorPattern(pattern, false), nil -} - -// SimilarToEscape checks if 'unescaped' is SIMILAR TO 'pattern' using custom escape token 'escape' -// which must be either empty (which disables the escape mechanism) or a single unicode character. -func SimilarToEscape(ctx *EvalContext, unescaped, pattern, escape string) (Datum, error) { - var escapeRune rune - if len(escape) > 0 { - var width int - escapeRune, width = utf8.DecodeRuneInString(escape) - if len(escape) > width { - return DBoolFalse, pgerror.Newf(pgcode.InvalidEscapeSequence, "invalid escape string") - } - } - key := similarToKey{s: pattern, escape: escapeRune} - return matchRegexpWithKey(ctx, NewDString(unescaped), key) -} - -type regexpKey struct { - s string - caseInsensitive bool -} - -// Pattern implements the RegexpCacheKey interface. -func (k regexpKey) Pattern() (string, error) { - if k.caseInsensitive { - return caseInsensitive(k.s), nil - } - return k.s, nil -} - -// SimilarEscape converts a SQL:2008 regexp pattern to POSIX style, so it can -// be used by our regexp engine. -func SimilarEscape(pattern string) string { - return similarEscapeCustomChar(pattern, '\\', true) -} - -// similarEscapeCustomChar converts a SQL:2008 regexp pattern to POSIX style, -// so it can be used by our regexp engine. This version of the function allows -// for a custom escape character. -// 'isEscapeNonEmpty' signals whether 'escapeChar' should be treated as empty. -func similarEscapeCustomChar(pattern string, escapeChar rune, isEscapeNonEmpty bool) string { - patternBuilder := make([]rune, 0, utf8.RuneCountInString(pattern)) - - inCharClass := false - afterEscape := false - numQuotes := 0 - for _, c := range pattern { - switch { - case afterEscape: - // For SUBSTRING patterns - if c == '"' && !inCharClass { - if numQuotes%2 == 0 { - patternBuilder = append(patternBuilder, '(') - } else { - patternBuilder = append(patternBuilder, ')') - } - numQuotes++ - } else if c == escapeChar && len(string(escapeChar)) > 1 { - // We encountered escaped escape unicode character represented by at least two bytes, - // so we keep only its single occurrence and need not to prepend it by '\'. - patternBuilder = append(patternBuilder, c) - } else { - patternBuilder = append(patternBuilder, '\\', c) - } - afterEscape = false - case utf8.ValidRune(escapeChar) && c == escapeChar && isEscapeNonEmpty: - // SQL99 escape character; do not immediately send to output - afterEscape = true - case inCharClass: - if c == '\\' { - patternBuilder = append(patternBuilder, '\\') - } - patternBuilder = append(patternBuilder, c) - if c == ']' { - inCharClass = false - } - case c == '[': - patternBuilder = append(patternBuilder, c) - inCharClass = true - case c == '%': - patternBuilder = append(patternBuilder, '.', '*') - case c == '_': - patternBuilder = append(patternBuilder, '.') - case c == '(': - // Convert to non-capturing parenthesis - patternBuilder = append(patternBuilder, '(', '?', ':') - case c == '\\', c == '.', c == '^', c == '$': - // Escape these characters because they are NOT - // metacharacters for SQL-style regexp - patternBuilder = append(patternBuilder, '\\', c) - default: - patternBuilder = append(patternBuilder, c) - } - } - - return string(patternBuilder) -} - -// caseInsensitive surrounds the transformed input string with -// (?i: ... ) -// which uses a non-capturing set of parens to turn a case sensitive -// regular expression pattern into a case insensitive regular -// expression pattern. -func caseInsensitive(pattern string) string { - return fmt.Sprintf("(?i:%s)", pattern) -} - -// anchorPattern surrounds the transformed input string with -// ^(?s: ... )$ -// which requires some explanation. We need "^" and "$" to force -// the pattern to match the entire input string as per SQL99 spec. -// The "(?:" and ")" are a non-capturing set of parens; we have to have -// parens in case the string contains "|", else the "^" and "$" will -// be bound into the first and last alternatives which is not what we -// want, and the parens must be non capturing because we don't want them -// to count when selecting output for SUBSTRING. -// "?s" turns on "dot all" mode meaning a dot will match any single character -// (without turning this mode on, the dot matches any single character except -// for line breaks). -func anchorPattern(pattern string, caseInsensitive bool) string { - if caseInsensitive { - return fmt.Sprintf("^(?si:%s)$", pattern) - } - return fmt.Sprintf("^(?s:%s)$", pattern) -} - -// FindEqualComparisonFunction looks up an overload of the "=" operator -// for a given pair of input operand types. -func FindEqualComparisonFunction(leftType, rightType *types.T) (TwoArgFn, bool) { - fn, found := CmpOps[EQ].LookupImpl(leftType, rightType) - if found { - return fn.Fn, true - } - return nil, false -} - -// IntPow computes the value of x^y. -func IntPow(x, y DInt) (*DInt, error) { - xd := apd.New(int64(x), 0) - yd := apd.New(int64(y), 0) - _, err := DecimalCtx.Pow(xd, xd, yd) - if err != nil { - return nil, err - } - i, err := xd.Int64() - if err != nil { - return nil, ErrIntOutOfRange - } - return NewDInt(DInt(i)), nil -} - -// PickFromTuple picks the greatest (or least value) from a tuple. -func PickFromTuple(ctx *EvalContext, greatest bool, args Datums) (Datum, error) { - g := args[0] - // Pick a greater (or smaller) value. - for _, d := range args[1:] { - var eval Datum - var err error - if greatest { - eval, err = evalComparison(ctx, LT, g, d) - } else { - eval, err = evalComparison(ctx, LT, d, g) - } - if err != nil { - return nil, err - } - if eval == DBoolTrue || - (eval == DNull && g == DNull) { - g = d - } - } - return g, nil -} - -// CallbackValueGenerator is a ValueGenerator that calls a supplied callback for -// producing the values. To be used with -// EvalContextTestingKnobs.CallbackGenerators. -type CallbackValueGenerator struct { - // cb is the callback to be called for producing values. It gets passed in 0 - // as prev initially, and the value it previously returned for subsequent - // invocations. Once it returns -1 or an error, it will not be invoked any - // more. - cb func(ctx context.Context, prev int, txn *kv.Txn) (int, error) - val int - txn *kv.Txn -} - -var _ ValueGenerator = &CallbackValueGenerator{} - -// NewCallbackValueGenerator creates a new CallbackValueGenerator. -func NewCallbackValueGenerator( - cb func(ctx context.Context, prev int, txn *kv.Txn) (int, error), -) *CallbackValueGenerator { - return &CallbackValueGenerator{ - cb: cb, - } -} - -// ResolvedType is part of the ValueGenerator interface. -func (c *CallbackValueGenerator) ResolvedType() *types.T { - return types.Int -} - -// Start is part of the ValueGenerator interface. -func (c *CallbackValueGenerator) Start(_ context.Context, txn *kv.Txn) error { - c.txn = txn - return nil -} - -// Next is part of the ValueGenerator interface. -func (c *CallbackValueGenerator) Next(ctx context.Context) (bool, error) { - var err error - c.val, err = c.cb(ctx, c.val, c.txn) - if err != nil { - return false, err - } - if c.val == -1 { - return false, nil - } - return true, nil -} - -// Values is part of the ValueGenerator interface. -func (c *CallbackValueGenerator) Values() (Datums, error) { - return Datums{NewDInt(DInt(c.val))}, nil -} - -// Close is part of the ValueGenerator interface. -func (c *CallbackValueGenerator) Close() {} - -// Sqrt returns the square root of x. -func Sqrt(x float64) (*DFloat, error) { - if x < 0.0 { - return nil, errSqrtOfNegNumber - } - return NewDFloat(DFloat(math.Sqrt(x))), nil -} - -// DecimalSqrt returns the square root of x. -func DecimalSqrt(x *apd.Decimal) (*DDecimal, error) { - if x.Sign() < 0 { - return nil, errSqrtOfNegNumber - } - dd := &DDecimal{} - _, err := DecimalCtx.Sqrt(&dd.Decimal, x) - return dd, err -} - -// Cbrt returns the cube root of x. -func Cbrt(x float64) (*DFloat, error) { - return NewDFloat(DFloat(math.Cbrt(x))), nil -} - -// DecimalCbrt returns the cube root of x. -func DecimalCbrt(x *apd.Decimal) (*DDecimal, error) { - dd := &DDecimal{} - _, err := DecimalCtx.Cbrt(&dd.Decimal, x) - return dd, err -} diff --git a/postgres/parser/sem/tree/expr.go b/postgres/parser/sem/tree/expr.go index bc2caa7fb2..dcc47add15 100644 --- a/postgres/parser/sem/tree/expr.go +++ b/postgres/parser/sem/tree/expr.go @@ -64,17 +64,6 @@ type Expr interface { // TypedExpr represents a well-typed expression. type TypedExpr interface { Expr - // Eval evaluates an SQL expression. Expression evaluation is a - // mostly straightforward walk over the parse tree. The only - // significant complexity is the handling of types and implicit - // conversions. See binOps and cmpOps for more details. Note that - // expression evaluation returns an error if certain node types are - // encountered: Placeholder, VarName (and related UnqualifiedStar, - // UnresolvedName and AllColumnsSelector) or Subquery. These nodes - // should be replaced prior to expression evaluation by an - // appropriate WalkExpr. For example, Placeholder should be replace - // by the argument passed from the client. - Eval(*EvalContext) (Datum, error) // ResolvedType provides the type of the TypedExpr, which is the type of Datum // that the TypedExpr will return when evaluated. ResolvedType() *types.T @@ -1109,11 +1098,6 @@ func (node *TypedDummy) TypeCheck(context.Context, *SemaContext, *types.T) (Type // Walk implements the Expr interface. func (node *TypedDummy) Walk(Visitor) Expr { return node } -// Eval implements the TypedExpr interface. -func (node *TypedDummy) Eval(*EvalContext) (Datum, error) { - return nil, errors.AssertionFailedf("should not eval typed dummy") -} - // BinaryOperator represents a binary operator. type BinaryOperator int @@ -1400,11 +1384,6 @@ func (node *FuncExpr) ResolvedOverload() *Overload { return node.fn } -// IsGeneratorApplication returns true iff the function applied is a generator (SRF). -func (node *FuncExpr) IsGeneratorApplication() bool { - return node.fn != nil && node.fn.Generator != nil -} - // IsWindowFunctionApplication returns true iff the function is being applied as a window function. func (node *FuncExpr) IsWindowFunctionApplication() bool { return node.WindowDef != nil diff --git a/postgres/parser/sem/tree/generators.go b/postgres/parser/sem/tree/generators.go deleted file mode 100644 index c8bf008756..0000000000 --- a/postgres/parser/sem/tree/generators.go +++ /dev/null @@ -1,81 +0,0 @@ -// Copyright 2023 Dolthub, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Copyright 2017 The Cockroach Authors. -// -// Use of this software is governed by the Business Source License -// included in the file licenses/BSL.txt. -// -// As of the Change Date specified in that file, in accordance with -// the Business Source License, use of this software will be governed -// by the Apache License, Version 2.0, included in the file -// licenses/APL.txt. - -package tree - -import ( - "context" - - "github.com/dolthub/doltgresql/postgres/parser/kv" - "github.com/dolthub/doltgresql/postgres/parser/types" -) - -// Table generators, also called "set-generating functions", are -// special functions that return an entire table. -// -// Overview of the concepts: -// -// - ValueGenerator is an interface that offers a -// Start/Next/Values/Stop API similar to sql.planNode. -// -// - because generators are regular functions, it is possible to use -// them in any expression context. This is useful to e.g -// pass an entire table as argument to the ARRAY( ) conversion -// function. -// -// - the data source mechanism in the sql package has a special case -// for generators appearing in FROM contexts and knows how to -// construct a special row source from them. - -// ValueGenerator is the interface provided by the value generator -// functions for SQL SRfs. Objects that implement this interface are -// able to produce rows of values in a streaming fashion (like Go -// iterators or generators in Python). -type ValueGenerator interface { - // ResolvedType returns the type signature of this value generator. - ResolvedType() *types.T - - // Start initializes the generator. Must be called once before - // Next() and Values(). It can be called again to restart - // the generator after Next() has returned false. - // - // txn represents the txn that the generator will run inside of. The generator - // is expected to hold on to this txn and use it in Next() calls. - Start(ctx context.Context, txn *kv.Txn) error - - // Next determines whether there is a row of data available. - Next(context.Context) (bool, error) - - // Values retrieves the current row of data. - Values() (Datums, error) - - // Close must be called after Start() before disposing of the - // ValueGenerator. It does not need to be called if Start() has not - // been called yet. It must not be called in-between restarts. - Close() -} - -// GeneratorFactory is the type of constructor functions for -// ValueGenerator objects. -type GeneratorFactory func(ctx *EvalContext, args Datums) (ValueGenerator, error) diff --git a/postgres/parser/sem/tree/indexed_vars.go b/postgres/parser/sem/tree/indexed_vars.go index d6e044233f..f404e8b149 100644 --- a/postgres/parser/sem/tree/indexed_vars.go +++ b/postgres/parser/sem/tree/indexed_vars.go @@ -37,7 +37,6 @@ import ( // IndexedVarContainer provides the implementation of TypeCheck, Eval, and // String for IndexedVars. type IndexedVarContainer interface { - IndexedVarEval(idx int, ctx *EvalContext) (Datum, error) IndexedVarResolvedType(idx int) *types.T // IndexedVarNodeFormatter returns a NodeFormatter; if an object that // wishes to implement this interface has lost the textual name that an @@ -86,15 +85,6 @@ func (v *IndexedVar) TypeCheck( return v, nil } -// Eval is part of the TypedExpr interface. -func (v *IndexedVar) Eval(ctx *EvalContext) (Datum, error) { - if ctx.IVarContainer == nil || ctx.IVarContainer == unboundContainer { - return nil, errors.AssertionFailedf( - "indexed var must be bound to a container before evaluation") - } - return ctx.IVarContainer.IndexedVarEval(v.Idx, ctx) -} - // ResolvedType is part of the TypedExpr interface. func (v *IndexedVar) ResolvedType() *types.T { if v.typ == nil { @@ -263,11 +253,6 @@ type unboundContainerType struct{} // is constant after parse). var unboundContainer = &unboundContainerType{} -// IndexedVarEval is part of the IndexedVarContainer interface. -func (*unboundContainerType) IndexedVarEval(idx int, _ *EvalContext) (Datum, error) { - return nil, errors.AssertionFailedf("unbound ordinal reference @%d", idx+1) -} - // IndexedVarResolvedType is part of the IndexedVarContainer interface. func (*unboundContainerType) IndexedVarResolvedType(idx int) *types.T { panic(errors.AssertionFailedf("unbound ordinal reference @%d", idx+1)) @@ -284,11 +269,6 @@ type typeContainer struct { var _ IndexedVarContainer = &typeContainer{} -// IndexedVarEval is part of the IndexedVarContainer interface. -func (tc *typeContainer) IndexedVarEval(idx int, ctx *EvalContext) (Datum, error) { - return nil, errors.AssertionFailedf("no eval allowed in typeContainer") -} - // IndexedVarResolvedType is part of the IndexedVarContainer interface. func (tc *typeContainer) IndexedVarResolvedType(idx int) *types.T { return tc.types[idx] diff --git a/postgres/parser/sem/tree/normalize.go b/postgres/parser/sem/tree/normalize.go deleted file mode 100644 index 18c38c00c7..0000000000 --- a/postgres/parser/sem/tree/normalize.go +++ /dev/null @@ -1,1005 +0,0 @@ -// Copyright 2023 Dolthub, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Copyright 2015 The Cockroach Authors. -// -// Use of this software is governed by the Business Source License -// included in the file licenses/BSL.txt. -// -// As of the Change Date specified in that file, in accordance with -// the Business Source License, use of this software will be governed -// by the Apache License, Version 2.0, included in the file -// licenses/APL.txt. - -package tree - -import ( - "github.com/cockroachdb/errors" - - "github.com/dolthub/doltgresql/postgres/parser/json" - "github.com/dolthub/doltgresql/postgres/parser/types" -) - -type normalizableExpr interface { - Expr - normalize(*NormalizeVisitor) TypedExpr -} - -func (expr *CastExpr) normalize(v *NormalizeVisitor) TypedExpr { - return expr -} - -func (expr *CoalesceExpr) normalize(v *NormalizeVisitor) TypedExpr { - // This normalization checks whether COALESCE can be simplified - // based on constant expressions at the start of the COALESCE - // argument list. All known-null constant arguments are simply - // removed, and any known-nonnull constant argument before - // non-constant argument cause the entire COALESCE expression to - // collapse to that argument. - last := len(expr.Exprs) - 1 - for i := range expr.Exprs { - subExpr := expr.TypedExprAt(i) - - if i == last { - return subExpr - } - - if !v.isConst(subExpr) { - exprCopy := *expr - exprCopy.Exprs = expr.Exprs[i:] - return &exprCopy - } - - val, err := subExpr.Eval(v.ctx) - if err != nil { - v.err = err - return expr - } - - if val != DNull { - return subExpr - } - } - return expr -} - -func (expr *IfExpr) normalize(v *NormalizeVisitor) TypedExpr { - if v.isConst(expr.Cond) { - cond, err := expr.TypedCondExpr().Eval(v.ctx) - if err != nil { - v.err = err - return expr - } - if d, err := GetBool(cond); err == nil { - if d { - return expr.TypedTrueExpr() - } - return expr.TypedElseExpr() - } - return DNull - } - return expr -} - -func (expr *UnaryExpr) normalize(v *NormalizeVisitor) TypedExpr { - val := expr.TypedInnerExpr() - - if val == DNull { - return val - } - - switch expr.Operator { - case UnaryMinus: - // -0 -> 0 (except for float which has negative zero) - if val.ResolvedType().Family() != types.FloatFamily && v.isNumericZero(val) { - return val - } - switch b := val.(type) { - // -(a - b) -> (b - a) - case *BinaryExpr: - if b.Operator == Minus { - newBinExpr := newBinExprIfValidOverload(Minus, - b.TypedRight(), b.TypedLeft()) - if newBinExpr != nil { - newBinExpr.memoizeFn() - b = newBinExpr - } - return b - } - // - (- a) -> a - case *UnaryExpr: - if b.Operator == UnaryMinus { - return b.TypedInnerExpr() - } - } - } - - return expr -} - -func (expr *BinaryExpr) normalize(v *NormalizeVisitor) TypedExpr { - left := expr.TypedLeft() - right := expr.TypedRight() - expectedType := expr.ResolvedType() - - if !expr.Fn.NullableArgs && (left == DNull || right == DNull) { - return DNull - } - - var final TypedExpr - - switch expr.Operator { - case Plus: - if v.isNumericZero(right) { - final = ReType(left, expectedType) - break - } - if v.isNumericZero(left) { - final = ReType(right, expectedType) - break - } - case Minus: - if types.IsAdditiveType(left.ResolvedType()) && v.isNumericZero(right) { - final = ReType(left, expectedType) - break - } - case Mult: - if v.isNumericOne(right) { - final = ReType(left, expectedType) - break - } - if v.isNumericOne(left) { - final = ReType(right, expectedType) - break - } - // We can't simplify multiplication by zero to zero, - // because if the other operand is NULL during evaluation - // the result must be NULL. - case Div, FloorDiv: - if v.isNumericOne(right) { - final = ReType(left, expectedType) - break - } - } - - if final == nil { - return expr - } - return final -} - -func (expr *AndExpr) normalize(v *NormalizeVisitor) TypedExpr { - left := expr.TypedLeft() - right := expr.TypedRight() - var dleft, dright Datum - - if left == DNull && right == DNull { - return DNull - } - - // Use short-circuit evaluation to simplify AND expressions. - if v.isConst(left) { - dleft, v.err = left.Eval(v.ctx) - if v.err != nil { - return expr - } - if dleft != DNull { - if d, err := GetBool(dleft); err == nil { - if !d { - return dleft - } - return right - } - return DNull - } - return NewTypedAndExpr( - dleft, - right, - ) - } - if v.isConst(right) { - dright, v.err = right.Eval(v.ctx) - if v.err != nil { - return expr - } - if dright != DNull { - if d, err := GetBool(dright); err == nil { - if !d { - return right - } - return left - } - return DNull - } - return NewTypedAndExpr( - left, - dright, - ) - } - return expr -} - -func (expr *ComparisonExpr) normalize(v *NormalizeVisitor) TypedExpr { - switch expr.Operator { - case EQ, GE, GT, LE, LT: - // We want var nodes (VariableExpr, VarName, etc) to be immediate - // children of the comparison expression and not second or third - // children. That is, we want trees that look like: - // - // cmp cmp - // / \ / \ - // a op op a - // / \ / \ - // 1 2 1 2 - // - // Not trees that look like: - // - // cmp cmp cmp cmp - // / \ / \ / \ / \ - // op 2 op 2 1 op 1 op - // / \ / \ / \ / \ - // a 1 1 a a 2 2 a - // - // We loop attempting to simplify the comparison expression. As a - // pre-condition, we know there is at least one variable in the expression - // tree or we would not have entered this code path. - exprCopied := false - for { - if expr.TypedLeft() == DNull || expr.TypedRight() == DNull { - return DNull - } - - if v.isConst(expr.Left) { - switch expr.Right.(type) { - case *BinaryExpr, VariableExpr: - break - default: - return expr - } - - invertedOp, err := invertComparisonOp(expr.Operator) - if err != nil { - v.err = err - return expr - } - - // The left side is const and the right side is a binary expression or a - // variable. Flip the comparison op so that the right side is const and - // the left side is a binary expression or variable. - // Create a new ComparisonExpr so the function cache isn't reused. - if !exprCopied { - exprCopy := *expr - expr = &exprCopy - exprCopied = true - } - - expr = NewTypedComparisonExpr(invertedOp, expr.TypedRight(), expr.TypedLeft()) - } else if !v.isConst(expr.Right) { - return expr - } - - left, ok := expr.Left.(*BinaryExpr) - if !ok { - return expr - } - // The right is const and the left side is a binary expression. Rotate the - // comparison combining portions that are const. - - switch { - case v.isConst(left.Right) && - (left.Operator == Plus || left.Operator == Minus || left.Operator == Div): - - // cmp cmp - // / \ / \ - // [+-/] 2 -> a [-+*] - // / \ / \ - // a 1 2 1 - var op BinaryOperator - switch left.Operator { - case Plus: - op = Minus - case Minus: - op = Plus - case Div: - op = Mult - if expr.Operator != EQ { - // In this case, we must remember to *flip* the inequality if the - // divisor is negative, since we are in effect multiplying both sides - // of the inequality by a negative number. - divisor, err := left.TypedRight().Eval(v.ctx) - if err != nil { - v.err = err - return expr - } - if divisor.Compare(v.ctx, DZero) < 0 { - if !exprCopied { - exprCopy := *expr - expr = &exprCopy - exprCopied = true - } - - invertedOp, err := invertComparisonOp(expr.Operator) - if err != nil { - v.err = err - return expr - } - expr = NewTypedComparisonExpr(invertedOp, expr.TypedLeft(), expr.TypedRight()) - } - } - } - - newBinExpr := newBinExprIfValidOverload(op, - expr.TypedRight(), left.TypedRight()) - if newBinExpr == nil { - // Substitution is not possible type-wise. Nothing else to do. - break - } - - newRightExpr, err := newBinExpr.Eval(v.ctx) - if err != nil { - // In the case of an error during Eval, give up on normalizing this - // expression. There are some expected errors here if, for example, - // normalization produces a result that overflows an int64. - break - } - - if !exprCopied { - exprCopy := *expr - expr = &exprCopy - exprCopied = true - } - - expr.Left = left.Left - expr.Right = newRightExpr - expr.memoizeFn() - if !isVar(v.ctx, expr.Left, true /*allowConstPlaceholders*/) { - // Continue as long as the left side of the comparison is not a - // variable. - continue - } - - case v.isConst(left.Left) && (left.Operator == Plus || left.Operator == Minus): - // cmp cmp - // / \ / \ - // [+-] 2 -> [+-] a - // / \ / \ - // 1 a 1 2 - - op := expr.Operator - var newBinExpr *BinaryExpr - - switch left.Operator { - case Plus: - // - // (A + X) cmp B => X cmp (B - C) - // - newBinExpr = newBinExprIfValidOverload(Minus, - expr.TypedRight(), left.TypedLeft()) - case Minus: - // - // (A - X) cmp B => X cmp' (A - B) - // - newBinExpr = newBinExprIfValidOverload(Minus, - left.TypedLeft(), expr.TypedRight()) - op, v.err = invertComparisonOp(op) - if v.err != nil { - return expr - } - } - - if newBinExpr == nil { - break - } - - newRightExpr, err := newBinExpr.Eval(v.ctx) - if err != nil { - break - } - - if !exprCopied { - exprCopy := *expr - expr = &exprCopy - exprCopied = true - } - - expr.Operator = op - expr.Left = left.Right - expr.Right = newRightExpr - expr.memoizeFn() - if !isVar(v.ctx, expr.Left, true /*allowConstPlaceholders*/) { - // Continue as long as the left side of the comparison is not a - // variable. - continue - } - - case expr.Operator == EQ && left.Operator == JSONFetchVal && v.isConst(left.Right) && - v.isConst(expr.Right): - // This is a JSONB inverted index normalization, changing things of the form - // x->y=z to x @> {y:z} which can be used to build spans for inverted index - // lookups. - - if left.TypedRight().ResolvedType().Family() != types.StringFamily { - break - } - - str, err := left.TypedRight().Eval(v.ctx) - if err != nil { - break - } - // Check that we still have a string after evaluation. - if _, ok := str.(*DString); !ok { - break - } - - rhs, err := expr.TypedRight().Eval(v.ctx) - if err != nil { - break - } - - rjson := rhs.(*DJSON).JSON - t := rjson.Type() - if t == json.ObjectJSONType || t == json.ArrayJSONType { - // We can't make this transformation in cases like - // - // a->'b' = '["c"]', - // - // because containment is not equivalent to equality for non-scalar types. - break - } - - j := json.NewObjectBuilder(1) - j.Add(string(*str.(*DString)), rjson) - - dj, err := MakeDJSON(j.Build()) - if err != nil { - break - } - - typedJ, err := dj.TypeCheck(v.ctx.Context, nil, types.Jsonb) - if err != nil { - break - } - - return NewTypedComparisonExpr(Contains, left.TypedLeft(), typedJ) - } - - // We've run out of work to do. - break - } - case In, NotIn: - // If the right tuple in an In or NotIn comparison expression is constant, it can - // be normalized. - tuple, ok := expr.Right.(*DTuple) - if ok { - tupleCopy := *tuple - tupleCopy.Normalize(v.ctx) - - // If the tuple only contains NULL values, Normalize will have reduced - // it to a single NULL value. - if len(tupleCopy.D) == 1 && tupleCopy.D[0] == DNull { - return DNull - } - if len(tupleCopy.D) == 0 { - // NULL IN is false. - if expr.Operator == In { - return DBoolFalse - } - return DBoolTrue - } - if expr.TypedLeft() == DNull { - // NULL IN is NULL. - return DNull - } - - exprCopy := *expr - expr = &exprCopy - expr.Right = &tupleCopy - } - case IsDistinctFrom, IsNotDistinctFrom: - left := expr.TypedLeft() - right := expr.TypedRight() - - if v.isConst(left) && !v.isConst(right) { - // Switch operand order so that constant expression is on the right. - // This helps support index selection rules. - return NewTypedComparisonExpr(expr.Operator, right, left) - } - case NE, - Like, NotLike, - ILike, NotILike, - SimilarTo, NotSimilarTo, - RegMatch, NotRegMatch, - RegIMatch, NotRegIMatch, - Any, Some, All: - if expr.TypedLeft() == DNull || expr.TypedRight() == DNull { - return DNull - } - } - - return expr -} - -func (expr *OrExpr) normalize(v *NormalizeVisitor) TypedExpr { - left := expr.TypedLeft() - right := expr.TypedRight() - var dleft, dright Datum - - if left == DNull && right == DNull { - return DNull - } - - // Use short-circuit evaluation to simplify OR expressions. - if v.isConst(left) { - dleft, v.err = left.Eval(v.ctx) - if v.err != nil { - return expr - } - if dleft != DNull { - if d, err := GetBool(dleft); err == nil { - if d { - return dleft - } - return right - } - return DNull - } - return NewTypedOrExpr( - dleft, - right, - ) - } - if v.isConst(right) { - dright, v.err = right.Eval(v.ctx) - if v.err != nil { - return expr - } - if dright != DNull { - if d, err := GetBool(dright); err == nil { - if d { - return right - } - return left - } - return DNull - } - return NewTypedOrExpr( - left, - dright, - ) - } - return expr -} - -func (expr *NotExpr) normalize(v *NormalizeVisitor) TypedExpr { - inner := expr.TypedInnerExpr() - switch t := inner.(type) { - case *NotExpr: - return t.TypedInnerExpr() - } - return expr -} - -func (expr *ParenExpr) normalize(v *NormalizeVisitor) TypedExpr { - return expr.TypedInnerExpr() -} - -func (expr *AnnotateTypeExpr) normalize(v *NormalizeVisitor) TypedExpr { - // Type annotations have no runtime effect, so they can be removed after - // semantic analysis. - return expr.TypedInnerExpr() -} - -func (expr *RangeCond) normalize(v *NormalizeVisitor) TypedExpr { - leftFrom, from := expr.TypedLeftFrom(), expr.TypedFrom() - leftTo, to := expr.TypedLeftTo(), expr.TypedTo() - // The visitor hasn't walked down into leftTo; do it now. - if leftTo, v.err = v.ctx.NormalizeExpr(leftTo); v.err != nil { - return expr - } - - if (leftFrom == DNull || from == DNull) && (leftTo == DNull || to == DNull) { - return DNull - } - - leftCmp := GE - rightCmp := LE - if expr.Not { - leftCmp = LT - rightCmp = GT - } - - // "a BETWEEN b AND c" -> "a >= b AND a <= c" - // "a NOT BETWEEN b AND c" -> "a < b OR a > c" - transform := func(from, to TypedExpr) TypedExpr { - var newLeft, newRight TypedExpr - if from == DNull { - newLeft = DNull - } else { - newLeft = NewTypedComparisonExpr(leftCmp, leftFrom, from).normalize(v) - if v.err != nil { - return expr - } - } - if to == DNull { - newRight = DNull - } else { - newRight = NewTypedComparisonExpr(rightCmp, leftTo, to).normalize(v) - if v.err != nil { - return expr - } - } - if expr.Not { - return NewTypedOrExpr(newLeft, newRight).normalize(v) - } - return NewTypedAndExpr(newLeft, newRight).normalize(v) - } - - out := transform(from, to) - if expr.Symmetric { - if expr.Not { - // "a NOT BETWEEN SYMMETRIC b AND c" -> "(a < b OR a > c) AND (a < c OR a > b)" - out = NewTypedAndExpr(out, transform(to, from)).normalize(v) - } else { - // "a BETWEEN SYMMETRIC b AND c" -> "(a >= b AND a <= c) OR (a >= c OR a <= b)" - out = NewTypedOrExpr(out, transform(to, from)).normalize(v) - } - } - return out -} - -func (expr *Tuple) normalize(v *NormalizeVisitor) TypedExpr { - // A Tuple should be directly evaluated into a DTuple if it's either fully - // constant or contains only constants and top-level Placeholders. - isConst := true - for _, subExpr := range expr.Exprs { - if !v.isConst(subExpr) { - isConst = false - break - } - } - if !isConst { - return expr - } - e, err := expr.Eval(v.ctx) - if err != nil { - v.err = err - } - return e -} - -// NormalizeExpr normalizes a typed expression, simplifying where possible, -// but guaranteeing that the result of evaluating the expression is -// unchanged and that resulting expression tree is still well-typed. -// Example normalizations: -// -// (a) -> a -// a = 1 + 1 -> a = 2 -// a + 1 = 2 -> a = 1 -// a BETWEEN b AND c -> (a >= b) AND (a <= c) -// a NOT BETWEEN b AND c -> (a < b) OR (a > c) -func (ctx *EvalContext) NormalizeExpr(typedExpr TypedExpr) (TypedExpr, error) { - v := MakeNormalizeVisitor(ctx) - expr, _ := WalkExpr(&v, typedExpr) - if v.err != nil { - return nil, v.err - } - return expr.(TypedExpr), nil -} - -// NormalizeVisitor supports the execution of NormalizeExpr. -type NormalizeVisitor struct { - ctx *EvalContext - err error - - fastIsConstVisitor fastIsConstVisitor -} - -var _ Visitor = &NormalizeVisitor{} - -// MakeNormalizeVisitor creates a NormalizeVisitor instance. -func MakeNormalizeVisitor(ctx *EvalContext) NormalizeVisitor { - return NormalizeVisitor{ctx: ctx, fastIsConstVisitor: fastIsConstVisitor{ctx: ctx}} -} - -// Err retrieves the error field in the NormalizeVisitor. -func (v *NormalizeVisitor) Err() error { return v.err } - -// VisitPre implements the Visitor interface. -func (v *NormalizeVisitor) VisitPre(expr Expr) (recurse bool, newExpr Expr) { - if v.err != nil { - return false, expr - } - - switch expr.(type) { - case *Subquery: - // Subqueries are pre-normalized during semantic analysis. There - // is nothing to do here. - return false, expr - } - - return true, expr -} - -// VisitPost implements the Visitor interface. -func (v *NormalizeVisitor) VisitPost(expr Expr) Expr { - if v.err != nil { - return expr - } - // We don't propagate errors during this step because errors might involve a - // branch of code that isn't traversed by normal execution (for example, - // IF(2 = 2, 1, 1 / 0)). - - // Normalize expressions that know how to normalize themselves. - if normalizable, ok := expr.(normalizableExpr); ok { - expr = normalizable.normalize(v) - if v.err != nil { - return expr - } - } - - // Evaluate all constant expressions. - if v.isConst(expr) { - value, err := expr.(TypedExpr).Eval(v.ctx) - if err != nil { - // Ignore any errors here (e.g. division by zero), so they can happen - // during execution where they are correctly handled. Note that in some - // cases we might not even get an error (if this particular expression - // does not get evaluated when the query runs, e.g. it's inside a CASE). - return expr - } - if value == DNull { - // We don't want to return an expression that has a different type; cast - // the NULL if necessary. - return ReType(DNull, expr.(TypedExpr).ResolvedType()) - } - return value - } - - return expr -} - -func (v *NormalizeVisitor) isConst(expr Expr) bool { - return v.fastIsConstVisitor.run(expr) -} - -// isNumericZero returns true if the datum is a number and equal to -// zero. -func (v *NormalizeVisitor) isNumericZero(expr TypedExpr) bool { - if d, ok := expr.(Datum); ok { - switch t := UnwrapDatum(v.ctx, d).(type) { - case *DDecimal: - return t.Decimal.Sign() == 0 - case *DFloat: - return *t == 0 - case *DInt: - return *t == 0 - } - } - return false -} - -// isNumericOne returns true if the datum is a number and equal to -// one. -func (v *NormalizeVisitor) isNumericOne(expr TypedExpr) bool { - if d, ok := expr.(Datum); ok { - switch t := UnwrapDatum(v.ctx, d).(type) { - case *DDecimal: - return t.Decimal.Cmp(&DecimalOne.Decimal) == 0 - case *DFloat: - return *t == 1.0 - case *DInt: - return *t == 1 - } - } - return false -} - -func invertComparisonOp(op ComparisonOperator) (ComparisonOperator, error) { - switch op { - case EQ: - return EQ, nil - case GE: - return LE, nil - case GT: - return LT, nil - case LE: - return GE, nil - case LT: - return GT, nil - default: - return op, errors.AssertionFailedf("unable to invert: %s", op) - } -} - -type isConstVisitor struct { - ctx *EvalContext - isConst bool -} - -var _ Visitor = &isConstVisitor{} - -func (v *isConstVisitor) VisitPre(expr Expr) (recurse bool, newExpr Expr) { - if v.isConst { - if !operatorIsImmutable(expr) || isVar(v.ctx, expr, true /*allowConstPlaceholders*/) { - v.isConst = false - return false, expr - } - } - return true, expr -} - -func operatorIsImmutable(expr Expr) bool { - switch t := expr.(type) { - case *FuncExpr: - return t.fnProps.Class == NormalClass && t.fn.Volatility <= VolatilityImmutable - - case *CastExpr: - volatility, ok := LookupCastVolatility(t.Expr.(TypedExpr).ResolvedType(), t.typ) - return ok && volatility <= VolatilityImmutable - - case *UnaryExpr: - return t.fn.Volatility <= VolatilityImmutable - - case *BinaryExpr: - return t.Fn.Volatility <= VolatilityImmutable - - case *ComparisonExpr: - return t.Fn.Volatility <= VolatilityImmutable - - default: - return true - } -} - -func (*isConstVisitor) VisitPost(expr Expr) Expr { return expr } - -func (v *isConstVisitor) run(expr Expr) bool { - v.isConst = true - WalkExprConst(v, expr) - return v.isConst -} - -// IsConst returns whether the expression is constant. A constant expression -// does not contain variables, as defined by ContainsVars, nor impure functions. -func IsConst(evalCtx *EvalContext, expr TypedExpr) bool { - v := isConstVisitor{ctx: evalCtx} - return v.run(expr) -} - -// fastIsConstVisitor is similar to isConstVisitor, but it only visits -// at most two levels of the tree (with one exception, see below). -// In essence, it determines whether an expression is constant by checking -// whether its children are const Datums. -// -// This can be used during normalization since constants are evaluated -// bottom-up. If a child is *not* a const Datum, that means it was already -// determined to be non-constant, and therefore was not evaluated. -type fastIsConstVisitor struct { - ctx *EvalContext - isConst bool - - // visited indicates whether we have already visited one level of the tree. - // fastIsConstVisitor only visits at most two levels of the tree, with one - // exception: If the second level has a Cast expression, fastIsConstVisitor - // may visit three levels. - visited bool -} - -var _ Visitor = &fastIsConstVisitor{} - -func (v *fastIsConstVisitor) VisitPre(expr Expr) (recurse bool, newExpr Expr) { - if v.visited { - if _, ok := expr.(*CastExpr); ok { - // We recurse one more time for cast expressions, since the - // NormalizeVisitor may have wrapped a NULL. - return true, expr - } - if _, ok := expr.(Datum); !ok || isVar(v.ctx, expr, true /*allowConstPlaceholders*/) { - // If the child expression is not a const Datum, the parent expression is - // not constant. Note that all constant literals have already been - // normalized to Datum in TypeCheck. - v.isConst = false - } - return false, expr - } - v.visited = true - - // If the parent expression is a variable or non-immutable operator, we know - // that it is not constant. - - if !operatorIsImmutable(expr) || isVar(v.ctx, expr, true /*allowConstPlaceholders*/) { - v.isConst = false - return false, expr - } - - return true, expr -} - -func (*fastIsConstVisitor) VisitPost(expr Expr) Expr { return expr } - -func (v *fastIsConstVisitor) run(expr Expr) bool { - v.isConst = true - v.visited = false - WalkExprConst(v, expr) - return v.isConst -} - -// isVar returns true if the expression's value can vary during plan -// execution. The parameter allowConstPlaceholders should be true -// in the common case of scalar expressions that will be evaluated -// in the context of the execution of a prepared query, where the -// placeholder will have the same value for every row processed. -// It is set to false for scalar expressions that are not -// evaluated as part of query execution, eg. DEFAULT expressions. -func isVar(evalCtx *EvalContext, expr Expr, allowConstPlaceholders bool) bool { - switch expr.(type) { - case VariableExpr: - return true - case *Placeholder: - if allowConstPlaceholders { - if evalCtx == nil || !evalCtx.HasPlaceholders() { - // The placeholder cannot be resolved -- it is variable. - return true - } - return evalCtx.Placeholders.IsUnresolvedPlaceholder(expr) - } - // Placeholders considered always variable. - return true - } - return false -} - -type containsVarsVisitor struct { - containsVars bool -} - -var _ Visitor = &containsVarsVisitor{} - -func (v *containsVarsVisitor) VisitPre(expr Expr) (recurse bool, newExpr Expr) { - if !v.containsVars && isVar(nil, expr, false /*allowConstPlaceholders*/) { - v.containsVars = true - } - if v.containsVars { - return false, expr - } - return true, expr -} - -func (*containsVarsVisitor) VisitPost(expr Expr) Expr { return expr } - -// ContainsVars returns true if the expression contains any variables. -// (variables = sub-expressions, placeholders, indexed vars, etc.) -func ContainsVars(expr Expr) bool { - v := containsVarsVisitor{containsVars: false} - WalkExprConst(&v, expr) - return v.containsVars -} - -// DecimalOne represents the constant 1 as DECIMAL. -var DecimalOne DDecimal - -func init() { - DecimalOne.SetInt64(1) -} - -// ReType ensures that the given numeric expression evaluates -// to the requested type, inserting a cast if necessary. -func ReType(expr TypedExpr, wantedType *types.T) TypedExpr { - if wantedType.Family() == types.AnyFamily || expr.ResolvedType().Identical(wantedType) { - return expr - } - res := &CastExpr{Expr: expr, Type: wantedType} - res.typ = wantedType - return res -} diff --git a/postgres/parser/sem/tree/overload.go b/postgres/parser/sem/tree/overload.go index 3801f2bc61..d566e3cf4f 100644 --- a/postgres/parser/sem/tree/overload.go +++ b/postgres/parser/sem/tree/overload.go @@ -77,15 +77,6 @@ type Overload struct { // might be more appropriate. Info string - AggregateFunc func([]*types.T, *EvalContext, Datums) AggregateFunc - WindowFunc func([]*types.T, *EvalContext) WindowFunc - Fn func(*EvalContext, Datums) (Datum, error) - Generator GeneratorFactory - - // SQLFn must be set for overloads of type SQLClass. It should return a SQL - // statement which will be executed as a common table expression in the query. - SQLFn func(*EvalContext, Datums) (string, error) - // SpecializedVecBuiltin is used to let the vectorized engine // know when an Overload has a specialized vectorized operator. SpecializedVecBuiltin SpecializedVectorizedBuiltin diff --git a/postgres/parser/sem/tree/pgwire_encode.go b/postgres/parser/sem/tree/pgwire_encode.go index 2d6d584efb..1ea85c3442 100644 --- a/postgres/parser/sem/tree/pgwire_encode.go +++ b/postgres/parser/sem/tree/pgwire_encode.go @@ -47,7 +47,7 @@ func (d *DTuple) pgwireFormat(ctx *FmtCtx) { comma := "" for _, v := range d.D { ctx.WriteString(comma) - switch dv := UnwrapDatum(nil, v).(type) { + switch dv := UnwrapDatum(v).(type) { case dNull: case *DString: pgwireFormatStringInTuple(&ctx.Buffer, string(*dv)) @@ -120,7 +120,7 @@ func (d *DArray) pgwireFormat(ctx *FmtCtx) { comma := "" for _, v := range d.Array { ctx.WriteString(comma) - switch dv := UnwrapDatum(nil, v).(type) { + switch dv := UnwrapDatum(v).(type) { case dNull: ctx.WriteString("NULL") case *DString: diff --git a/postgres/parser/sem/tree/regexp_cache.go b/postgres/parser/sem/tree/regexp_cache.go deleted file mode 100644 index 4e5fc5c010..0000000000 --- a/postgres/parser/sem/tree/regexp_cache.go +++ /dev/null @@ -1,123 +0,0 @@ -// Copyright 2023 Dolthub, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Copyright 2015 The Cockroach Authors. -// -// Use of this software is governed by the Business Source License -// included in the file licenses/BSL.txt. -// -// As of the Change Date specified in that file, in accordance with -// the Business Source License, use of this software will be governed -// by the Apache License, Version 2.0, included in the file -// licenses/APL.txt. - -package tree - -import ( - "regexp" - - "github.com/dolthub/doltgresql/postgres/parser/cache" - "github.com/dolthub/doltgresql/postgres/parser/syncutil" -) - -// RegexpCacheKey allows cache keys to take the form of different types, -// as long as they are comparable and can produce a pattern when needed -// for regexp compilation. The pattern method will not be called until -// after a cache lookup is performed and the result is a miss. -type RegexpCacheKey interface { - Pattern() (string, error) -} - -// A RegexpCache is a cache used to store compiled regular expressions. -// The cache is safe for concurrent use by multiple goroutines. It is also -// safe to use the cache through a nil reference, where it will act like a valid -// cache with no capacity. -type RegexpCache struct { - mu syncutil.Mutex - cache *cache.UnorderedCache -} - -// NewRegexpCache creates a new RegexpCache of the given size. -// The underlying cache internally uses a hash map, so lookups -// are cheap. -func NewRegexpCache(size int) *RegexpCache { - return &RegexpCache{ - cache: cache.NewUnorderedCache(cache.Config{ - Policy: cache.CacheLRU, - ShouldEvict: func(s int, key, value interface{}) bool { - return s > size - }, - }), - } -} - -// GetRegexp consults the cache for the regular expressions stored for -// the given key, compiling the key's pattern if it is not already -// in the cache. -func (rc *RegexpCache) GetRegexp(key RegexpCacheKey) (*regexp.Regexp, error) { - if rc != nil { - re := rc.lookup(key) - if re != nil { - return re, nil - } - } - - pattern, err := key.Pattern() - if err != nil { - return nil, err - } - - re, err := regexp.Compile(pattern) - if err != nil { - return nil, err - } - - if rc != nil { - rc.update(key, re) - } - return re, nil -} - -// lookup checks for the regular expression in the cache in a -// synchronized manner, returning it if it exists. -func (rc *RegexpCache) lookup(key RegexpCacheKey) *regexp.Regexp { - rc.mu.Lock() - defer rc.mu.Unlock() - v, ok := rc.cache.Get(key) - if !ok { - return nil - } - return v.(*regexp.Regexp) -} - -// update invalidates the regular expression for the given pattern. -// If a new regular expression is passed in, it is inserted into the cache. -func (rc *RegexpCache) update(key RegexpCacheKey, re *regexp.Regexp) { - rc.mu.Lock() - defer rc.mu.Unlock() - rc.cache.Del(key) - if re != nil { - rc.cache.Add(key, re) - } -} - -// Len returns the number of compiled regular expressions in the cache. -func (rc *RegexpCache) Len() int { - if rc == nil { - return 0 - } - rc.mu.Lock() - defer rc.mu.Unlock() - return rc.cache.Len() -} diff --git a/postgres/parser/sem/tree/window_funcs.go b/postgres/parser/sem/tree/window_funcs.go deleted file mode 100644 index 0a18fdb7ce..0000000000 --- a/postgres/parser/sem/tree/window_funcs.go +++ /dev/null @@ -1,665 +0,0 @@ -// Copyright 2023 Dolthub, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Copyright 2017 The Cockroach Authors. -// -// Use of this software is governed by the Business Source License -// included in the file licenses/BSL.txt. -// -// As of the Change Date specified in that file, in accordance with -// the Business Source License, use of this software will be governed -// by the Apache License, Version 2.0, included in the file -// licenses/APL.txt. - -package tree - -import ( - "context" - "sort" - - "github.com/cockroachdb/errors" - - "github.com/dolthub/doltgresql/postgres/parser/encoding" - "github.com/dolthub/doltgresql/postgres/parser/types" -) - -// IndexedRows are rows with the corresponding indices. -type IndexedRows interface { - Len() int // returns number of rows - GetRow(ctx context.Context, idx int) (IndexedRow, error) // returns a row at the given index or an error -} - -// IndexedRow is a row with a corresponding index. -type IndexedRow interface { - GetIdx() int // returns index of the row - GetDatum(idx int) (Datum, error) // returns a datum at the given index - GetDatums(startIdx, endIdx int) (Datums, error) // returns datums at indices [startIdx, endIdx) -} - -// WindowFrameRun contains the runtime state of window frame during calculations. -type WindowFrameRun struct { - // constant for all calls to WindowFunc.Add - Rows IndexedRows - ArgsIdxs []uint32 // indices of the arguments to the window function - Frame *WindowFrame // If non-nil, Frame represents the frame specification of this window. If nil, default frame is used. - StartBoundOffset Datum - EndBoundOffset Datum - FilterColIdx int - OrdColIdx int // Column over which rows are ordered within the partition. It is only required in RANGE mode. - OrdDirection encoding.Direction // Direction of the ordering over OrdColIdx. - PlusOp, MinusOp *BinOp // Binary operators for addition and subtraction required only in RANGE mode. - PeerHelper PeerGroupsIndicesHelper - - // Any error that occurred within methods that cannot return an error (like - // within a closure that is passed into sort.Search()). - err error - - // changes for each peer group - CurRowPeerGroupNum int // the number of the current row's peer group - - // changes for each row (each call to WindowFunc.Add) - RowIdx int // the current row index -} - -// WindowFrameRangeOps allows for looking up an implementation of binary -// operators necessary for RANGE mode of framing. -type WindowFrameRangeOps struct{} - -// LookupImpl looks up implementation of Plus and Minus binary operators for -// provided left and right types and returns them along with a boolean which -// indicates whether lookup is successful. -func (o WindowFrameRangeOps) LookupImpl(left, right *types.T) (*BinOp, *BinOp, bool) { - plusOverloads, minusOverloads := BinOps[Plus], BinOps[Minus] - plusOp, found := plusOverloads.lookupImpl(left, right) - if !found { - return nil, nil, false - } - minusOp, found := minusOverloads.lookupImpl(left, right) - if !found { - return nil, nil, false - } - return plusOp, minusOp, true -} - -// getValueByOffset returns a datum calculated as the value of the current row -// in the column over which rows are ordered plus/minus logic offset, and an -// error if encountered. It should be used only in RANGE mode. -func (wfr *WindowFrameRun) getValueByOffset( - ctx context.Context, evalCtx *EvalContext, offset Datum, negative bool, -) (Datum, error) { - if wfr.OrdDirection == encoding.Descending { - // If rows are in descending order, we want to perform the "opposite" - // addition/subtraction to default ascending order. - negative = !negative - } - var binOp *BinOp - if negative { - binOp = wfr.MinusOp - } else { - binOp = wfr.PlusOp - } - value, err := wfr.valueAt(ctx, wfr.RowIdx) - if err != nil { - return nil, err - } - if value == DNull { - return DNull, nil - } - return binOp.Fn(evalCtx, value, offset) -} - -// FrameStartIdx returns the index of starting row in the frame (which is the first to be included). -func (wfr *WindowFrameRun) FrameStartIdx(ctx context.Context, evalCtx *EvalContext) (int, error) { - if wfr.Frame == nil { - return 0, nil - } - switch wfr.Frame.Mode { - case RANGE: - switch wfr.Frame.Bounds.StartBound.BoundType { - case UnboundedPreceding: - return 0, nil - case OffsetPreceding: - value, err := wfr.getValueByOffset(ctx, evalCtx, wfr.StartBoundOffset, true /* negative */) - if err != nil { - return 0, err - } - if wfr.OrdDirection == encoding.Descending { - // We use binary search on [0, wfr.RowIdx) interval to find the first row - // whose value is smaller or equal to 'value'. If such row is not found, - // then Search will correctly return wfr.RowIdx. - return sort.Search(wfr.RowIdx, func(i int) bool { - if wfr.err != nil { - return false - } - valueAt, err := wfr.valueAt(ctx, i) - if err != nil { - wfr.err = err - return false - } - return valueAt.Compare(evalCtx, value) <= 0 - }), wfr.err - } - // We use binary search on [0, wfr.RowIdx) interval to find the first row - // whose value is greater or equal to 'value'. If such row is not found, - // then Search will correctly return wfr.RowIdx. - return sort.Search(wfr.RowIdx, func(i int) bool { - if wfr.err != nil { - return false - } - valueAt, err := wfr.valueAt(ctx, i) - if err != nil { - wfr.err = err - return false - } - return valueAt.Compare(evalCtx, value) >= 0 - }), wfr.err - case CurrentRow: - // Spec: in RANGE mode CURRENT ROW means that the frame starts with the current row's first peer. - return wfr.PeerHelper.GetFirstPeerIdx(wfr.CurRowPeerGroupNum), nil - case OffsetFollowing: - value, err := wfr.getValueByOffset(ctx, evalCtx, wfr.StartBoundOffset, false /* negative */) - if err != nil { - return 0, err - } - if wfr.OrdDirection == encoding.Descending { - // We use binary search on [0, wfr.PartitionSize()) interval to find - // the first row whose value is smaller or equal to 'value'. - return sort.Search(wfr.PartitionSize(), func(i int) bool { - if wfr.err != nil { - return false - } - valueAt, err := wfr.valueAt(ctx, i) - if err != nil { - wfr.err = err - return false - } - return valueAt.Compare(evalCtx, value) <= 0 - }), wfr.err - } - // We use binary search on [0, wfr.PartitionSize()) interval to find the - // first row whose value is greater or equal to 'value'. - return sort.Search(wfr.PartitionSize(), func(i int) bool { - if wfr.err != nil { - return false - } - valueAt, err := wfr.valueAt(ctx, i) - if err != nil { - wfr.err = err - return false - } - return valueAt.Compare(evalCtx, value) >= 0 - }), wfr.err - default: - return 0, errors.AssertionFailedf( - "unexpected WindowFrameBoundType in RANGE mode: %d", - wfr.Frame.Bounds.StartBound.BoundType) - } - case ROWS: - switch wfr.Frame.Bounds.StartBound.BoundType { - case UnboundedPreceding: - return 0, nil - case OffsetPreceding: - offset := MustBeDInt(wfr.StartBoundOffset) - idx := wfr.RowIdx - int(offset) - if idx < 0 { - idx = 0 - } - return idx, nil - case CurrentRow: - return wfr.RowIdx, nil - case OffsetFollowing: - offset := MustBeDInt(wfr.StartBoundOffset) - idx := wfr.RowIdx + int(offset) - if idx >= wfr.PartitionSize() { - idx = wfr.unboundedFollowing() - } - return idx, nil - default: - return 0, errors.AssertionFailedf( - "unexpected WindowFrameBoundType in ROWS mode: %d", - wfr.Frame.Bounds.StartBound.BoundType) - } - case GROUPS: - switch wfr.Frame.Bounds.StartBound.BoundType { - case UnboundedPreceding: - return 0, nil - case OffsetPreceding: - offset := MustBeDInt(wfr.StartBoundOffset) - peerGroupNum := wfr.CurRowPeerGroupNum - int(offset) - if peerGroupNum < 0 { - peerGroupNum = 0 - } - return wfr.PeerHelper.GetFirstPeerIdx(peerGroupNum), nil - case CurrentRow: - // Spec: in GROUPS mode CURRENT ROW means that the frame starts with the current row's first peer. - return wfr.PeerHelper.GetFirstPeerIdx(wfr.CurRowPeerGroupNum), nil - case OffsetFollowing: - offset := MustBeDInt(wfr.StartBoundOffset) - peerGroupNum := wfr.CurRowPeerGroupNum + int(offset) - lastPeerGroupNum := wfr.PeerHelper.GetLastPeerGroupNum() - if peerGroupNum > lastPeerGroupNum || peerGroupNum < 0 { - // peerGroupNum is out of bounds, so we return the index of the first - // row after the partition. - return wfr.unboundedFollowing(), nil - } - return wfr.PeerHelper.GetFirstPeerIdx(peerGroupNum), nil - default: - return 0, errors.AssertionFailedf( - "unexpected WindowFrameBoundType in GROUPS mode: %d", - wfr.Frame.Bounds.StartBound.BoundType) - } - default: - return 0, errors.AssertionFailedf("unexpected WindowFrameMode: %d", wfr.Frame.Mode) - } -} - -// IsDefaultFrame returns whether a frame equivalent to the default frame -// is being used (default is RANGE UNBOUNDED PRECEDING). -func (f *WindowFrame) IsDefaultFrame() bool { - if f == nil { - return true - } - if f.Bounds.StartBound.BoundType == UnboundedPreceding { - return f.DefaultFrameExclusion() && f.Mode == RANGE && - (f.Bounds.EndBound == nil || f.Bounds.EndBound.BoundType == CurrentRow) - } - return false -} - -// DefaultFrameExclusion returns true if optional frame exclusion is omitted. -func (f *WindowFrame) DefaultFrameExclusion() bool { - return f == nil || f.Exclusion == NoExclusion -} - -// FrameEndIdx returns the index of the first row after the frame. -func (wfr *WindowFrameRun) FrameEndIdx(ctx context.Context, evalCtx *EvalContext) (int, error) { - if wfr.Frame == nil { - return wfr.DefaultFrameSize(), nil - } - switch wfr.Frame.Mode { - case RANGE: - if wfr.Frame.Bounds.EndBound == nil { - // We're using default value of CURRENT ROW when EndBound is omitted. - // Spec: in RANGE mode CURRENT ROW means that the frame ends with the current row's last peer. - return wfr.DefaultFrameSize(), nil - } - switch wfr.Frame.Bounds.EndBound.BoundType { - case OffsetPreceding: - value, err := wfr.getValueByOffset(ctx, evalCtx, wfr.EndBoundOffset, true /* negative */) - if err != nil { - return 0, err - } - if wfr.OrdDirection == encoding.Descending { - // We use binary search on [0, wfr.PartitionSize()) interval to find - // the first row whose value is smaller than 'value'. If such row is - // not found, then Search will correctly return wfr.PartitionSize(). - // Note that searching up to wfr.RowIdx is not correct in case of a - // zero offset (we need to include all peers of the current row). - return sort.Search(wfr.PartitionSize(), func(i int) bool { - if wfr.err != nil { - return false - } - valueAt, err := wfr.valueAt(ctx, i) - if err != nil { - wfr.err = err - return false - } - return valueAt.Compare(evalCtx, value) < 0 - }), wfr.err - } - // We use binary search on [0, wfr.PartitionSize()) interval to find - // the first row whose value is smaller than 'value'. If such row is - // not found, then Search will correctly return wfr.PartitionSize(). - // Note that searching up to wfr.RowIdx is not correct in case of a - // zero offset (we need to include all peers of the current row). - return sort.Search(wfr.PartitionSize(), func(i int) bool { - if wfr.err != nil { - return false - } - valueAt, err := wfr.valueAt(ctx, i) - if err != nil { - wfr.err = err - return false - } - return valueAt.Compare(evalCtx, value) > 0 - }), wfr.err - case CurrentRow: - // Spec: in RANGE mode CURRENT ROW means that the frame end with the current row's last peer. - return wfr.DefaultFrameSize(), nil - case OffsetFollowing: - value, err := wfr.getValueByOffset(ctx, evalCtx, wfr.EndBoundOffset, false /* negative */) - if err != nil { - return 0, err - } - if wfr.OrdDirection == encoding.Descending { - // We use binary search on [0, wfr.PartitionSize()) interval to find - // the first row whose value is smaller than 'value'. - return sort.Search(wfr.PartitionSize(), func(i int) bool { - if wfr.err != nil { - return false - } - valueAt, err := wfr.valueAt(ctx, i) - if err != nil { - wfr.err = err - return false - } - return valueAt.Compare(evalCtx, value) < 0 - }), wfr.err - } - // We use binary search on [0, wfr.PartitionSize()) interval to find - // the first row whose value is smaller than 'value'. - return sort.Search(wfr.PartitionSize(), func(i int) bool { - if wfr.err != nil { - return false - } - valueAt, err := wfr.valueAt(ctx, i) - if err != nil { - wfr.err = err - return false - } - return valueAt.Compare(evalCtx, value) > 0 - }), wfr.err - case UnboundedFollowing: - return wfr.unboundedFollowing(), nil - default: - return 0, errors.AssertionFailedf( - "unexpected WindowFrameBoundType in RANGE mode: %d", - wfr.Frame.Bounds.EndBound.BoundType) - } - case ROWS: - if wfr.Frame.Bounds.EndBound == nil { - // We're using default value of CURRENT ROW when EndBound is omitted. - return wfr.RowIdx + 1, nil - } - switch wfr.Frame.Bounds.EndBound.BoundType { - case OffsetPreceding: - offset := MustBeDInt(wfr.EndBoundOffset) - idx := wfr.RowIdx - int(offset) + 1 - if idx < 0 { - idx = 0 - } - return idx, nil - case CurrentRow: - return wfr.RowIdx + 1, nil - case OffsetFollowing: - offset := MustBeDInt(wfr.EndBoundOffset) - idx := wfr.RowIdx + int(offset) + 1 - if idx >= wfr.PartitionSize() { - idx = wfr.unboundedFollowing() - } - return idx, nil - case UnboundedFollowing: - return wfr.unboundedFollowing(), nil - default: - return 0, errors.AssertionFailedf( - "unexpected WindowFrameBoundType in ROWS mode: %d", - wfr.Frame.Bounds.EndBound.BoundType) - } - case GROUPS: - if wfr.Frame.Bounds.EndBound == nil { - // We're using default value of CURRENT ROW when EndBound is omitted. - // Spec: in GROUPS mode CURRENT ROW means that the frame ends with the current row's last peer. - return wfr.DefaultFrameSize(), nil - } - switch wfr.Frame.Bounds.EndBound.BoundType { - case OffsetPreceding: - offset := MustBeDInt(wfr.EndBoundOffset) - peerGroupNum := wfr.CurRowPeerGroupNum - int(offset) - if peerGroupNum < 0 { - // EndBound's peer group is "outside" of the partition. - return 0, nil - } - return wfr.PeerHelper.GetFirstPeerIdx(peerGroupNum) + wfr.PeerHelper.GetRowCount(peerGroupNum), nil - case CurrentRow: - return wfr.DefaultFrameSize(), nil - case OffsetFollowing: - offset := MustBeDInt(wfr.EndBoundOffset) - peerGroupNum := wfr.CurRowPeerGroupNum + int(offset) - lastPeerGroupNum := wfr.PeerHelper.GetLastPeerGroupNum() - if peerGroupNum > lastPeerGroupNum || peerGroupNum < 0 { - // peerGroupNum is out of bounds, so we return the index of the first - // row after the partition. - return wfr.unboundedFollowing(), nil - } - return wfr.PeerHelper.GetFirstPeerIdx(peerGroupNum) + wfr.PeerHelper.GetRowCount(peerGroupNum), nil - case UnboundedFollowing: - return wfr.unboundedFollowing(), nil - default: - return 0, errors.AssertionFailedf( - "unexpected WindowFrameBoundType in GROUPS mode: %d", - wfr.Frame.Bounds.EndBound.BoundType) - } - default: - return 0, errors.AssertionFailedf( - "unexpected WindowFrameMode: %d", wfr.Frame.Mode) - } -} - -// FrameSize returns the number of rows in the current frame (taking into -// account - if present - a filter and a frame exclusion). -func (wfr *WindowFrameRun) FrameSize(ctx context.Context, evalCtx *EvalContext) (int, error) { - if wfr.Frame == nil { - return wfr.DefaultFrameSize(), nil - } - frameEndIdx, err := wfr.FrameEndIdx(ctx, evalCtx) - if err != nil { - return 0, err - } - frameStartIdx, err := wfr.FrameStartIdx(ctx, evalCtx) - if err != nil { - return 0, err - } - size := frameEndIdx - frameStartIdx - if !wfr.noFilter() || !wfr.Frame.DefaultFrameExclusion() { - size = 0 - for idx := frameStartIdx; idx < frameEndIdx; idx++ { - if skipped, err := wfr.IsRowSkipped(ctx, idx); err != nil { - return 0, err - } else if skipped { - continue - } - size++ - } - } - if size <= 0 { - size = 0 - } - return size, nil -} - -// Rank returns the rank of the current row. -func (wfr *WindowFrameRun) Rank() int { - return wfr.RowIdx + 1 -} - -// PartitionSize returns the number of rows in the current partition. -func (wfr *WindowFrameRun) PartitionSize() int { - return wfr.Rows.Len() -} - -// unboundedFollowing returns the index of the "first row beyond" the partition -// so that current frame contains all the rows till the end of the partition. -func (wfr *WindowFrameRun) unboundedFollowing() int { - return wfr.PartitionSize() -} - -// DefaultFrameSize returns the size of default window frame which contains -// the rows from the start of the partition through the last peer of the current row. -func (wfr *WindowFrameRun) DefaultFrameSize() int { - return wfr.PeerHelper.GetFirstPeerIdx(wfr.CurRowPeerGroupNum) + wfr.PeerHelper.GetRowCount(wfr.CurRowPeerGroupNum) -} - -// FirstInPeerGroup returns if the current row is the first in its peer group. -func (wfr *WindowFrameRun) FirstInPeerGroup() bool { - return wfr.RowIdx == wfr.PeerHelper.GetFirstPeerIdx(wfr.CurRowPeerGroupNum) -} - -// Args returns the current argument set in the window frame. -func (wfr *WindowFrameRun) Args(ctx context.Context) (Datums, error) { - return wfr.ArgsWithRowOffset(ctx, 0) -} - -// ArgsWithRowOffset returns the argument set at the given offset in the window frame. -func (wfr *WindowFrameRun) ArgsWithRowOffset(ctx context.Context, offset int) (Datums, error) { - return wfr.ArgsByRowIdx(ctx, wfr.RowIdx+offset) -} - -// ArgsByRowIdx returns the argument set of the row at idx. -func (wfr *WindowFrameRun) ArgsByRowIdx(ctx context.Context, idx int) (Datums, error) { - row, err := wfr.Rows.GetRow(ctx, idx) - if err != nil { - return nil, err - } - datums := make(Datums, len(wfr.ArgsIdxs)) - for i, argIdx := range wfr.ArgsIdxs { - datums[i], err = row.GetDatum(int(argIdx)) - if err != nil { - return nil, err - } - } - return datums, nil -} - -// valueAt returns the first argument of the window function at the row idx. -func (wfr *WindowFrameRun) valueAt(ctx context.Context, idx int) (Datum, error) { - row, err := wfr.Rows.GetRow(ctx, idx) - if err != nil { - return nil, err - } - return row.GetDatum(wfr.OrdColIdx) -} - -// RangeModeWithOffsets returns whether the frame is in RANGE mode with at least -// one of the bounds containing an offset. -func (wfr *WindowFrameRun) RangeModeWithOffsets() bool { - return wfr.Frame.Mode == RANGE && wfr.Frame.Bounds.HasOffset() -} - -// FullPartitionIsInWindow checks whether we have such a window frame that all -// rows of the partition are inside of the window for each of the rows. -func (wfr *WindowFrameRun) FullPartitionIsInWindow() bool { - // Note that we do not need to check whether a filter is present because - // application of the filter to a row does not depend on the position of the - // row or whether it is inside of the window frame. - if wfr.Frame == nil || !wfr.Frame.DefaultFrameExclusion() { - return false - } - if wfr.Frame.Bounds.EndBound == nil { - // If the end bound is omitted, it is CURRENT ROW (the default value) which - // doesn't guarantee full partition in the window for all rows. - return false - } - // precedingConfirmed and followingConfirmed indicate whether, for every row, - // all preceding and following, respectively, rows are always in the window. - precedingConfirmed := wfr.Frame.Bounds.StartBound.BoundType == UnboundedPreceding - followingConfirmed := wfr.Frame.Bounds.EndBound.BoundType == UnboundedFollowing - if wfr.Frame.Mode == ROWS || wfr.Frame.Mode == GROUPS { - // Every peer group in GROUPS modealways contains at least one row, so - // treating GROUPS as ROWS here is a subset of the cases when we should - // return true. - if wfr.Frame.Bounds.StartBound.BoundType == OffsetPreceding { - // Both ROWS and GROUPS have an offset of integer type, so this type - // conversion is safe. - startOffset := wfr.StartBoundOffset.(*DInt) - // The idea of this conditional is that to confirm that all preceding - // rows will always be in the window, we only need to look at the last - // row: if startOffset is at least as large as the number of rows in the - // partition before the last one, then it will be true for the first to - // last, second to last, etc. - precedingConfirmed = precedingConfirmed || *startOffset >= DInt(wfr.Rows.Len()-1) - } - if wfr.Frame.Bounds.EndBound.BoundType == OffsetFollowing { - // Both ROWS and GROUPS have an offset of integer type, so this type - // conversion is safe. - endOffset := wfr.EndBoundOffset.(*DInt) - // The idea of this conditional is that to confirm that all following - // rows will always be in the window, we only need to look at the first - // row: if endOffset is at least as large as the number of rows in the - // partition after the first one, then it will be true for the second, - // third, etc rows as well. - followingConfirmed = followingConfirmed || *endOffset >= DInt(wfr.Rows.Len()-1) - } - } - return precedingConfirmed && followingConfirmed -} - -// noFilter returns whether a filter is present. -func (wfr *WindowFrameRun) noFilter() bool { - return wfr.FilterColIdx == NoColumnIdx -} - -// isRowExcluded returns whether the row at index idx should be excluded from -// the window frame of the current row. -func (wfr *WindowFrameRun) isRowExcluded(idx int) (bool, error) { - if wfr.Frame.DefaultFrameExclusion() { - // By default, no rows are excluded. - return false, nil - } - switch wfr.Frame.Exclusion { - case ExcludeCurrentRow: - return idx == wfr.RowIdx, nil - case ExcludeGroup: - curRowFirstPeerIdx := wfr.PeerHelper.GetFirstPeerIdx(wfr.CurRowPeerGroupNum) - curRowPeerGroupRowCount := wfr.PeerHelper.GetRowCount(wfr.CurRowPeerGroupNum) - return curRowFirstPeerIdx <= idx && idx < curRowFirstPeerIdx+curRowPeerGroupRowCount, nil - case ExcludeTies: - curRowFirstPeerIdx := wfr.PeerHelper.GetFirstPeerIdx(wfr.CurRowPeerGroupNum) - curRowPeerGroupRowCount := wfr.PeerHelper.GetRowCount(wfr.CurRowPeerGroupNum) - return curRowFirstPeerIdx <= idx && idx < curRowFirstPeerIdx+curRowPeerGroupRowCount && idx != wfr.RowIdx, nil - default: - return false, errors.AssertionFailedf("unexpected WindowFrameExclusion") - } -} - -// IsRowSkipped returns whether a row at index idx is skipped from the window -// frame (it can either be filtered out according to the filter clause or -// excluded according to the frame exclusion clause) and any error if it -// occurs. -func (wfr *WindowFrameRun) IsRowSkipped(ctx context.Context, idx int) (bool, error) { - if !wfr.noFilter() { - row, err := wfr.Rows.GetRow(ctx, idx) - if err != nil { - return false, err - } - d, err := row.GetDatum(wfr.FilterColIdx) - if err != nil { - return false, err - } - if d != DBoolTrue { - // Row idx is filtered out from the window frame, so it is skipped. - return true, nil - } - } - // If a row is excluded from the window frame, it is skipped. - return wfr.isRowExcluded(idx) -} - -// WindowFunc performs a computation on each row using data from a provided *WindowFrameRun. -type WindowFunc interface { - // Compute computes the window function for the provided window frame, given the - // current state of WindowFunc. The method should be called sequentially for every - // row in a partition in turn with the desired ordering of the WindowFunc. This is - // because there is an implicit carried dependency between each row and all those - // that have come before it (like in an AggregateFunc). As such, this approach does - // not present any exploitable associativity/commutativity for optimization. - Compute(context.Context, *EvalContext, *WindowFrameRun) (Datum, error) - - // Reset resets the window function which allows for reusing it when - // computing over different partitions. - Reset(context.Context) - - // Close allows the window function to free any memory it requested during execution, - // such as during the execution of an aggregation like CONCAT_AGG or ARRAY_AGG. - Close(context.Context, *EvalContext) -} diff --git a/postgres/parser/sem/tree/window_funcs_util.go b/postgres/parser/sem/tree/window_funcs_util.go deleted file mode 100644 index c811b15c9d..0000000000 --- a/postgres/parser/sem/tree/window_funcs_util.go +++ /dev/null @@ -1,243 +0,0 @@ -// Copyright 2023 Dolthub, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Copyright 2018 The Cockroach Authors. -// -// Use of this software is governed by the Business Source License -// included in the file licenses/BSL.txt. -// -// As of the Change Date specified in that file, in accordance with -// the Business Source License, use of this software will be governed -// by the Apache License, Version 2.0, included in the file -// licenses/APL.txt. - -package tree - -import "github.com/dolthub/doltgresql/postgres/parser/ring" - -// PeerGroupChecker can check if a pair of row indices within a partition are -// in the same peer group. It also returns an error if it occurs while checking -// the peer groups. -type PeerGroupChecker interface { - InSameGroup(i, j int) (bool, error) -} - -// peerGroup contains information about a single peer group. -type peerGroup struct { - firstPeerIdx int - rowCount int -} - -// PeerGroupsIndicesHelper computes peer groups using the given -// PeerGroupChecker. In ROWS and RANGE modes, it processes one peer group at -// a time and stores information only about single peer group. In GROUPS mode, -// it's behavior depends on the frame bounds; in the worst case, it stores -// max(F, O) peer groups at the same time, where F is the maximum number of -// peer groups within the frame at any point and O is the maximum of two -// offsets if we have OFFSET_FOLLOWING type of bound (both F and O are -// upper-bounded by total number of peer groups). -type PeerGroupsIndicesHelper struct { - groups ring.Buffer // queue of peer groups - peerGrouper PeerGroupChecker - headPeerGroupNum int // number of the peer group at the head of the queue - allPeerGroupsSkipped bool // in GROUP mode, indicates whether all peer groups were skipped during Init - allRowsProcessed bool // indicates whether peer groups for all rows within partition have been already computed - unboundedFollowing int // index of the first row after all rows of the partition -} - -// Init computes all peer groups necessary to perform calculations of a window -// function over the first row of the partition. It returns any error if it -// occurs. -func (p *PeerGroupsIndicesHelper) Init(wfr *WindowFrameRun, peerGrouper PeerGroupChecker) error { - // We first reset the helper to reuse the same one for all partitions when - // computing a particular window function. - p.groups.Reset() - p.headPeerGroupNum = 0 - p.allPeerGroupsSkipped = false - p.allRowsProcessed = false - p.unboundedFollowing = wfr.unboundedFollowing() - - var group *peerGroup - p.peerGrouper = peerGrouper - startIdxOfFirstPeerGroupWithinFrame := 0 - if wfr.Frame != nil && wfr.Frame.Mode == GROUPS && wfr.Frame.Bounds.StartBound.BoundType == OffsetFollowing { - // In GROUPS mode with OFFSET_FOLLOWING as a start bound, 'peerGroupOffset' - // number of peer groups needs to be processed upfront before we get to - // peer groups that will be within a frame of the first row. - // If start bound is of type: - // - UNBOUNDED_PRECEDING - we don't use this helper at all - // - OFFSET_PRECEDING - no need to process any peer groups upfront - // - CURRENT_ROW - no need to process any peer groups upfront - // - OFFSET_FOLLOWING - processing is done here - // - UNBOUNDED_FOLLOWING - invalid as a start bound - // - // We also cannot simply discard information about these peer groups: even - // though they will never be within frames of any rows, we still might need - // information about them. For example, with frame as follows: - // GROUPS BETWEEN 1 FOLLOWING AND 3 FOLLOWING - // when processing the rows from zeroth peer group, we will need to know - // where zeroth peer group starts and how many rows it has, but the rows of - // zeroth group will never be in any frame. - peerGroupOffset := int(MustBeDInt(wfr.StartBoundOffset)) - group = &peerGroup{firstPeerIdx: 0, rowCount: 1} - for group.firstPeerIdx < wfr.PartitionSize() && p.groups.Len() < peerGroupOffset { - p.groups.AddLast(group) - for ; group.firstPeerIdx+group.rowCount < wfr.PartitionSize(); group.rowCount++ { - idx := group.firstPeerIdx + group.rowCount - if sameGroup, err := p.peerGrouper.InSameGroup(idx-1, idx); err != nil { - return err - } else if !sameGroup { - break - } - } - group = &peerGroup{firstPeerIdx: group.firstPeerIdx + group.rowCount, rowCount: 1} - } - - if group.firstPeerIdx == wfr.PartitionSize() { - // Frame starts after all peer groups of the partition. - p.allPeerGroupsSkipped = true - return nil - } - - startIdxOfFirstPeerGroupWithinFrame = group.firstPeerIdx - } - - // Compute the first peer group that is within the frame. - group = &peerGroup{firstPeerIdx: startIdxOfFirstPeerGroupWithinFrame, rowCount: 1} - p.groups.AddLast(group) - for ; group.firstPeerIdx+group.rowCount < wfr.PartitionSize(); group.rowCount++ { - idx := group.firstPeerIdx + group.rowCount - if sameGroup, err := p.peerGrouper.InSameGroup(idx-1, idx); err != nil { - return err - } else if !sameGroup { - break - } - } - if group.firstPeerIdx+group.rowCount == wfr.PartitionSize() { - p.allRowsProcessed = true - return nil - } - - if wfr.Frame != nil && wfr.Frame.Mode == GROUPS && wfr.Frame.Bounds.EndBound != nil && wfr.Frame.Bounds.EndBound.BoundType == OffsetFollowing { - // In GROUPS mode, 'peerGroupOffset' number of peer groups need to be - // processed upfront because they are within the frame of the first row. - // If end bound is of type: - // - UNBOUNDED_PRECEDING - invalid as an end bound - // - OFFSET_PRECEDING - no need to process any peer groups upfront - // - CURRENT_ROW - no need to process any more peer groups upfront - // - OFFSET_FOLLOWING - processing is done here - // - UNBOUNDED_FOLLOWING - we don't use this helper at all - peerGroupOffset := int(MustBeDInt(wfr.EndBoundOffset)) - group = &peerGroup{firstPeerIdx: group.firstPeerIdx + group.rowCount, rowCount: 1} - for group.firstPeerIdx < wfr.PartitionSize() && p.groups.Len() <= peerGroupOffset { - p.groups.AddLast(group) - for ; group.firstPeerIdx+group.rowCount < wfr.PartitionSize(); group.rowCount++ { - idx := group.firstPeerIdx + group.rowCount - if sameGroup, err := p.peerGrouper.InSameGroup(idx-1, idx); err != nil { - return err - } else if !sameGroup { - break - } - } - group = &peerGroup{firstPeerIdx: group.firstPeerIdx + group.rowCount, rowCount: 1} - } - if group.firstPeerIdx == wfr.PartitionSize() { - p.allRowsProcessed = true - } - } - return nil -} - -// Update should be called after a window function has been computed over all -// rows in wfr.CurRowPeerGroupNum peer group. If not all rows have been already -// processed, it computes the next peer group. It returns any error if it -// occurs. -func (p *PeerGroupsIndicesHelper) Update(wfr *WindowFrameRun) error { - if p.allPeerGroupsSkipped { - // No peer groups to process. - return nil - } - - // nextPeerGroupStartIdx is the index of the first row that we haven't - // computed peer group for. - lastPeerGroup := p.groups.GetLast().(*peerGroup) - nextPeerGroupStartIdx := lastPeerGroup.firstPeerIdx + lastPeerGroup.rowCount - - if (wfr.Frame == nil || wfr.Frame.Mode == ROWS || wfr.Frame.Mode == RANGE) || - (wfr.Frame.Bounds.StartBound.BoundType == OffsetPreceding && wfr.CurRowPeerGroupNum-p.headPeerGroupNum > int(MustBeDInt(wfr.StartBoundOffset)) || - wfr.Frame.Bounds.StartBound.BoundType == CurrentRow || - (wfr.Frame.Bounds.StartBound.BoundType == OffsetFollowing && p.headPeerGroupNum-wfr.CurRowPeerGroupNum > int(MustBeDInt(wfr.StartBoundOffset)))) { - // With default frame, ROWS or RANGE mode, we want to "discard" the only - // peer group that we're storing information about. In GROUPS mode, with - // start bound of type: - // - OFFSET_PRECEDING we want to start discarding the "earliest" peer group - // only when the number of current row's peer group differs from the - // number of the earliest one by more than offset - // - CURRENT_ROW we want to discard the earliest peer group - // - OFFSET_FOLLOWING we want to start discarding the "earliest" peer group - // only when the number of current row's peer group differs from the - // number of the earliest one by more than offset - p.groups.RemoveFirst() - p.headPeerGroupNum++ - } - - if p.allRowsProcessed { - // No more peer groups to process. - return nil - } - - // Compute the next peer group that is just entering the frame. - peerGroup := &peerGroup{firstPeerIdx: nextPeerGroupStartIdx, rowCount: 1} - p.groups.AddLast(peerGroup) - for ; peerGroup.firstPeerIdx+peerGroup.rowCount < wfr.PartitionSize(); peerGroup.rowCount++ { - idx := peerGroup.firstPeerIdx + peerGroup.rowCount - if sameGroup, err := p.peerGrouper.InSameGroup(idx-1, idx); err != nil { - return err - } else if !sameGroup { - break - } - } - if peerGroup.firstPeerIdx+peerGroup.rowCount == wfr.PartitionSize() { - p.allRowsProcessed = true - } - return nil -} - -// GetFirstPeerIdx returns index of the first peer within peer group of number -// peerGroupNum (counting from 0). -func (p *PeerGroupsIndicesHelper) GetFirstPeerIdx(peerGroupNum int) int { - posInBuffer := peerGroupNum - p.headPeerGroupNum - if posInBuffer < 0 || p.groups.Len() < posInBuffer { - panic("peerGroupNum out of bounds") - } - return p.groups.Get(posInBuffer).(*peerGroup).firstPeerIdx -} - -// GetRowCount returns the number of rows within peer group of number -// peerGroupNum (counting from 0). -func (p *PeerGroupsIndicesHelper) GetRowCount(peerGroupNum int) int { - posInBuffer := peerGroupNum - p.headPeerGroupNum - if posInBuffer < 0 || p.groups.Len() < posInBuffer { - panic("peerGroupNum out of bounds") - } - return p.groups.Get(posInBuffer).(*peerGroup).rowCount -} - -// GetLastPeerGroupNum returns the number of the last peer group in the queue. -func (p *PeerGroupsIndicesHelper) GetLastPeerGroupNum() int { - if p.groups.Len() == 0 { - panic("GetLastPeerGroupNum on empty RingBuffer") - } - return p.headPeerGroupNum + p.groups.Len() - 1 -} diff --git a/postgres/parser/sessiondata/search_path.go b/postgres/parser/sessiondata/search_path.go index 48e3f3d22d..18f6f96d19 100644 --- a/postgres/parser/sessiondata/search_path.go +++ b/postgres/parser/sessiondata/search_path.go @@ -37,16 +37,6 @@ const PgCatalogName = "pg_catalog" // PublicSchemaName is the name of the pg_catalog system schema. const PublicSchemaName = "public" -// InformationSchemaName is the name of the information_schema system schema. -const InformationSchemaName = "information_schema" - -// CRDBInternalSchemaName is the name of the crdb_internal system schema. -const CRDBInternalSchemaName = "crdb_internal" - -// PgSchemaPrefix is a prefix for Postgres system schemas. Users cannot -// create schemas with this prefix. -const PgSchemaPrefix = "pg_" - // PgTempSchemaName is the alias for temporary schemas across sessions. const PgTempSchemaName = "pg_temp" diff --git a/postgres/parser/sessiondata/sequence_state.go b/postgres/parser/sessiondata/sequence_state.go deleted file mode 100644 index 70622414af..0000000000 --- a/postgres/parser/sessiondata/sequence_state.go +++ /dev/null @@ -1,116 +0,0 @@ -// Copyright 2023 Dolthub, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Copyright 2018 The Cockroach Authors. -// -// Use of this software is governed by the Business Source License -// included in the file licenses/BSL.txt. -// -// As of the Change Date specified in that file, in accordance with -// the Business Source License, use of this software will be governed -// by the Apache License, Version 2.0, included in the file -// licenses/APL.txt. - -package sessiondata - -import ( - "github.com/dolthub/doltgresql/postgres/parser/pgcode" - "github.com/dolthub/doltgresql/postgres/parser/pgerror" - "github.com/dolthub/doltgresql/postgres/parser/syncutil" -) - -// SequenceState stores session-scoped state used by sequence builtins. -// -// All public methods of SequenceState are thread-safe, as the structure is -// meant to be shared by statements executing in parallel on a session. -type SequenceState struct { - mu struct { - syncutil.Mutex - // latestValues stores the last value obtained by nextval() in this session - // by descriptor id. - latestValues map[uint32]int64 - - // lastSequenceIncremented records the descriptor id of the last sequence - // nextval() was called on in this session. - lastSequenceIncremented uint32 - } -} - -// NewSequenceState creates a SequenceState. -func NewSequenceState() *SequenceState { - ss := SequenceState{} - ss.mu.latestValues = make(map[uint32]int64) - return &ss -} - -// NextVal ever called returns true if a sequence has ever been incremented on -// this session. -func (ss *SequenceState) nextValEverCalledLocked() bool { - return len(ss.mu.latestValues) > 0 -} - -// RecordValue records the latest manipulation of a sequence done by a session. -func (ss *SequenceState) RecordValue(seqID uint32, val int64) { - ss.mu.Lock() - ss.mu.lastSequenceIncremented = seqID - ss.mu.latestValues[seqID] = val - ss.mu.Unlock() -} - -// SetLastSequenceIncremented sets the id of the last incremented sequence. -// Usually this id is set through RecordValue(). -func (ss *SequenceState) SetLastSequenceIncremented(seqID uint32) { - ss.mu.Lock() - ss.mu.lastSequenceIncremented = seqID - ss.mu.Unlock() -} - -// GetLastValue returns the value most recently obtained by -// nextval() for the last sequence for which RecordLatestVal() was called. -func (ss *SequenceState) GetLastValue() (int64, error) { - ss.mu.Lock() - defer ss.mu.Unlock() - - if !ss.nextValEverCalledLocked() { - return 0, pgerror.New( - pgcode.ObjectNotInPrerequisiteState, "lastval is not yet defined in this session") - } - - return ss.mu.latestValues[ss.mu.lastSequenceIncremented], nil -} - -// GetLastValueByID returns the value most recently obtained by nextval() for -// the given sequence in this session. -// The bool retval is false if RecordLatestVal() was never called on the -// requested sequence. -func (ss *SequenceState) GetLastValueByID(seqID uint32) (int64, bool) { - ss.mu.Lock() - defer ss.mu.Unlock() - - val, ok := ss.mu.latestValues[seqID] - return val, ok -} - -// Export returns a copy of the SequenceState's state - the latestValues and -// lastSequenceIncremented. -// lastSequenceIncremented is only defined if latestValues is non-empty. -func (ss *SequenceState) Export() (map[uint32]int64, uint32) { - ss.mu.Lock() - defer ss.mu.Unlock() - res := make(map[uint32]int64, len(ss.mu.latestValues)) - for k, v := range ss.mu.latestValues { - res[k] = v - } - return res, ss.mu.lastSequenceIncremented -} diff --git a/postgres/parser/sessiondata/session_data.go b/postgres/parser/sessiondata/session_data.go deleted file mode 100644 index 1fb3546930..0000000000 --- a/postgres/parser/sessiondata/session_data.go +++ /dev/null @@ -1,418 +0,0 @@ -// Copyright 2023 Dolthub, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Copyright 2018 The Cockroach Authors. -// -// Use of this software is governed by the Business Source License -// included in the file licenses/BSL.txt. -// -// As of the Change Date specified in that file, in accordance with -// the Business Source License, use of this software will be governed -// by the Apache License, Version 2.0, included in the file -// licenses/APL.txt. - -package sessiondata - -import ( - "fmt" - "net" - "strings" - "time" - - "github.com/dolthub/doltgresql/postgres/parser/pgnotice" - - "github.com/dolthub/doltgresql/postgres/parser/lex" -) - -// SessionData contains session parameters. They are all user-configurable. -// A SQL Session changes fields in SessionData through sql.sessionDataMutator. -type SessionData struct { - // ApplicationName is the name of the application running the - // current session. This can be used for logging and per-application - // statistics. - ApplicationName string - // Database indicates the "current" database for the purpose of - // resolving names. See searchAndQualifyDatabase() for details. - Database string - // DefaultTxnPriority indicates the default priority of newly created - // transactions. - // NOTE: we'd prefer to use tree.UserPriority here, but doing so would - // introduce a package dependency cycle. - DefaultTxnPriority int - // DefaultReadOnly indicates the default read-only status of newly created - // transactions. - DefaultReadOnly bool - // DistSQLMode indicates whether to run queries using the distributed - // execution engine. - DistSQLMode DistSQLExecMode - // ExperimentalDistSQLPlanningMode indicates whether the experimental - // DistSQL planning driven by the optimizer is enabled. - ExperimentalDistSQLPlanningMode ExperimentalDistSQLPlanningMode - // PartiallyDistributedPlansDisabled indicates whether the partially - // distributed plans produced by distSQLSpecExecFactory are disabled. It - // should be set to 'true' only in tests that verify that the old and the - // new factories return exactly the same physical plans. - // TODO(yuzefovich): remove it when deleting old sql.execFactory. - PartiallyDistributedPlansDisabled bool - // OptimizerFKCascadesLimit is the maximum number of cascading operations that - // are run for a single query. - OptimizerFKCascadesLimit int - // OptimizerUseHistograms indicates whether we should use histograms for - // cardinality estimation in the optimizer. - OptimizerUseHistograms bool - // OptimizerUseMultiColStats indicates whether we should use multi-column - // statistics for cardinality estimation in the optimizer. - OptimizerUseMultiColStats bool - // SerialNormalizationMode indicates how to handle the SERIAL pseudo-type. - SerialNormalizationMode SerialNormalizationMode - // SearchPath is a list of namespaces to search builtins in. - SearchPath SearchPath - // TemporarySchemaID is the ID of the current session's temporary schema, - // if it exists. It is a descpb.ID, but cannot be stored as one due to - // packaging dependencies. - TemporarySchemaID uint32 - // StmtTimeout is the duration a query is permitted to run before it is - // canceled by the session. If set to 0, there is no timeout. - StmtTimeout time.Duration - // IdleInSessionTimeout is the duration a session is permitted to idle before - // the session is canceled. If set to 0, there is no timeout. - IdleInSessionTimeout time.Duration - // IdleInTransactionSessionTimeout is the duration a session is permitted to - // idle in a transaction before the session is canceled. - // If set to 0, there is no timeout. - IdleInTransactionSessionTimeout time.Duration - // User is the name of the user logged into the session. - User string - // SafeUpdates causes errors when the client - // sends syntax that may have unwanted side effects. - SafeUpdates bool - // PreferLookupJoinsForFKs causes foreign key operations to prefer lookup - // joins. - PreferLookupJoinsForFKs bool - // RemoteAddr is used to generate logging events. - RemoteAddr net.Addr - // ZigzagJoinEnabled indicates whether the optimizer should try and plan a - // zigzag join. - ZigzagJoinEnabled bool - // ReorderJoinsLimit indicates the number of joins at which the optimizer should - // stop attempting to reorder. - ReorderJoinsLimit int - // RequireExplicitPrimaryKeys indicates whether CREATE TABLE statements should - // error out if no primary key is provided. - RequireExplicitPrimaryKeys bool - // SequenceState gives access to the SQL sequences that have been manipulated - // by the session. - SequenceState *SequenceState - // DataConversion gives access to the data conversion configuration. - DataConversion DataConversionConfig - // VectorizeMode indicates which kinds of queries to use vectorized execution - // engine for. - VectorizeMode VectorizeExecMode - // VectorizeRowCountThreshold indicates the row count above which the - // vectorized execution engine will be used if possible. - VectorizeRowCountThreshold uint64 - // ForceSavepointRestart overrides the default SAVEPOINT behavior - // for compatibility with certain ORMs. When this flag is set, - // the savepoint name will no longer be compared against the magic - // identifier `cockroach_restart` in order use a restartable - // transaction. - ForceSavepointRestart bool - // DefaultIntSize specifies the size in bits or bytes (preferred) - // of how a "naked" INT type should be parsed. - DefaultIntSize int - // ResultsBufferSize specifies the size at which the pgwire results buffer - // will self-flush. - ResultsBufferSize int64 - // AllowPrepareAsOptPlan must be set to allow use of - // PREPARE name AS OPT PLAN '...' - AllowPrepareAsOptPlan bool - // SaveTablesPrefix indicates that a table should be created with the - // given prefix for the output of each subexpression in a query. If - // SaveTablesPrefix is empty, no tables are created. - SaveTablesPrefix string - // TempTablesEnabled indicates whether temporary tables can be created or not. - TempTablesEnabled bool - // HashShardedIndexesEnabled indicates whether hash sharded indexes can be created. - HashShardedIndexesEnabled bool - // DisallowFullTableScans indicates whether queries that plan full table scans - // should be rejected. - DisallowFullTableScans bool - // ImplicitSelectForUpdate is true if FOR UPDATE locking may be used during - // the row-fetch phase of mutation statements. - ImplicitSelectForUpdate bool - // InsertFastPath is true if the fast path for insert (with VALUES input) may - // be used. - InsertFastPath bool - // InterleavedJoins is true if interleaved joins may be used. - InterleavedJoins bool - // NoticeDisplaySeverity indicates the level of Severity to send notices for the given - // session. - NoticeDisplaySeverity pgnotice.DisplaySeverity - // AlterColumnTypeGeneralEnabled is true if ALTER TABLE ... ALTER COLUMN ... - // TYPE x may be used for general conversions requiring online schema change/ - AlterColumnTypeGeneralEnabled bool - - // SynchronousCommit is a dummy setting for the synchronous_commit var. - SynchronousCommit bool - // EnableSeqScan is a dummy setting for the enable_seqscan var. - EnableSeqScan bool -} - -// DataConversionConfig contains the parameters that influence -// the conversion between SQL data types and strings/byte arrays. -type DataConversionConfig struct { - // Location indicates the current time zone. - Location *time.Location - - // BytesEncodeFormat indicates how to encode byte arrays when converting - // to string. - BytesEncodeFormat lex.BytesEncodeFormat - - // ExtraFloatDigits indicates the number of digits beyond the - // standard number to use for float conversions. - // This must be set to a value between -15 and 3, inclusive. - ExtraFloatDigits int -} - -// GetFloatPrec computes a precision suitable for a call to -// strconv.FormatFloat() or for use with '%.*g' in a printf-like -// function. -func (c *DataConversionConfig) GetFloatPrec() int { - // The user-settable parameter ExtraFloatDigits indicates the number - // of digits to be used to format the float value. PostgreSQL - // combines this with %g. - // The formula is _DIG + extra_float_digits, - // where is either FLT (float4) or DBL (float8). - - // Also the value "3" in PostgreSQL is special and meant to mean - // "all the precision needed to reproduce the float exactly". The Go - // formatter uses the special value -1 for this and activates a - // separate path in the formatter. We compare >= 3 here - // just in case the value is not gated properly in the implementation - // of SET. - if c.ExtraFloatDigits >= 3 { - return -1 - } - - // CockroachDB only implements float8 at this time and Go does not - // expose DBL_DIG, so we use the standard literal constant for - // 64bit floats. - const StdDoubleDigits = 15 - - nDigits := StdDoubleDigits + c.ExtraFloatDigits - if nDigits < 1 { - // Ensure the value is clamped at 1: printf %g does not allow - // values lower than 1. PostgreSQL does this too. - nDigits = 1 - } - return nDigits -} - -// ExperimentalDistSQLPlanningMode controls if and when the opt-driven DistSQL -// planning is used to create physical plans. -type ExperimentalDistSQLPlanningMode int64 - -const ( - // ExperimentalDistSQLPlanningOff means that we always use the old path of - // going from opt.Expr to planNodes and then to processor specs. - ExperimentalDistSQLPlanningOff ExperimentalDistSQLPlanningMode = iota - // ExperimentalDistSQLPlanningOn means that we will attempt to use the new - // path for performing DistSQL planning in the optimizer, and if that - // doesn't succeed for some reason, we will fallback to the old path. - ExperimentalDistSQLPlanningOn - // ExperimentalDistSQLPlanningAlways means that we will only use the new path, - // and if it fails for any reason, the query will fail as well. - ExperimentalDistSQLPlanningAlways -) - -func (m ExperimentalDistSQLPlanningMode) String() string { - switch m { - case ExperimentalDistSQLPlanningOff: - return "off" - case ExperimentalDistSQLPlanningOn: - return "on" - case ExperimentalDistSQLPlanningAlways: - return "always" - default: - return fmt.Sprintf("invalid (%d)", m) - } -} - -// ExperimentalDistSQLPlanningModeFromString converts a string into a -// ExperimentalDistSQLPlanningMode. False is returned if the conversion was -// unsuccessful. -func ExperimentalDistSQLPlanningModeFromString(val string) (ExperimentalDistSQLPlanningMode, bool) { - var m ExperimentalDistSQLPlanningMode - switch strings.ToUpper(val) { - case "OFF": - m = ExperimentalDistSQLPlanningOff - case "ON": - m = ExperimentalDistSQLPlanningOn - case "ALWAYS": - m = ExperimentalDistSQLPlanningAlways - default: - return 0, false - } - return m, true -} - -// DistSQLExecMode controls if and when the Executor distributes queries. -// Since 2.1, we run everything through the DistSQL infrastructure, -// and these settings control whether to use a distributed plan, or use a plan -// that only involves local DistSQL processors. -type DistSQLExecMode int64 - -const ( - // DistSQLOff means that we never distribute queries. - DistSQLOff DistSQLExecMode = iota - // DistSQLAuto means that we automatically decide on a case-by-case basis if - // we distribute queries. - DistSQLAuto - // DistSQLOn means that we distribute queries that are supported. - DistSQLOn - // DistSQLAlways means that we only distribute; unsupported queries fail. - DistSQLAlways -) - -func (m DistSQLExecMode) String() string { - switch m { - case DistSQLOff: - return "off" - case DistSQLAuto: - return "auto" - case DistSQLOn: - return "on" - case DistSQLAlways: - return "always" - default: - return fmt.Sprintf("invalid (%d)", m) - } -} - -// DistSQLExecModeFromString converts a string into a DistSQLExecMode -func DistSQLExecModeFromString(val string) (_ DistSQLExecMode, ok bool) { - switch strings.ToUpper(val) { - case "OFF": - return DistSQLOff, true - case "AUTO": - return DistSQLAuto, true - case "ON": - return DistSQLOn, true - case "ALWAYS": - return DistSQLAlways, true - default: - return 0, false - } -} - -// VectorizeExecMode controls if an when the Executor executes queries using the -// columnar execution engine. -// WARNING: When adding a VectorizeExecMode, note that nodes at previous -// versions might interpret the integer value differently. To avoid this, only -// append to the list or bump the minimum required distsql version (maybe also -// take advantage of that to reorder the list as you see fit). -type VectorizeExecMode int64 - -const ( - // VectorizeOff means that columnar execution is disabled. - VectorizeOff VectorizeExecMode = iota - // Vectorize201Auto means that that any supported queries that use only - // streaming operators (i.e. those that do not require any buffering) will - // be run using the columnar execution. If any part of a query is not - // supported by the vectorized execution engine, the whole query will fall - // back to row execution. - // This is the default setting in 20.1. - Vectorize201Auto - // VectorizeOn means that any supported queries will be run using the - // columnar execution. - VectorizeOn - // VectorizeExperimentalAlways means that we attempt to vectorize all - // queries; unsupported queries will fail. Mostly used for testing. - VectorizeExperimentalAlways -) - -func (m VectorizeExecMode) String() string { - switch m { - case VectorizeOff: - return "off" - case Vectorize201Auto: - return "201auto" - case VectorizeOn: - return "on" - case VectorizeExperimentalAlways: - return "experimental_always" - default: - return fmt.Sprintf("invalid (%d)", m) - } -} - -// VectorizeExecModeFromString converts a string into a VectorizeExecMode. False -// is returned if the conversion was unsuccessful. -func VectorizeExecModeFromString(val string) (VectorizeExecMode, bool) { - var m VectorizeExecMode - switch strings.ToUpper(val) { - case "OFF": - m = VectorizeOff - case "201AUTO": - m = Vectorize201Auto - case "ON": - m = VectorizeOn - case "EXPERIMENTAL_ALWAYS": - m = VectorizeExperimentalAlways - default: - return 0, false - } - return m, true -} - -// SerialNormalizationMode controls if and when the Executor uses DistSQL. -type SerialNormalizationMode int64 - -const ( - // SerialUsesRowID means use INT NOT NULL DEFAULT unique_rowid(). - SerialUsesRowID SerialNormalizationMode = iota - // SerialUsesVirtualSequences means create a virtual sequence and - // use INT NOT NULL DEFAULT nextval(...). - SerialUsesVirtualSequences - // SerialUsesSQLSequences means create a regular SQL sequence and - // use INT NOT NULL DEFAULT nextval(...). - SerialUsesSQLSequences -) - -func (m SerialNormalizationMode) String() string { - switch m { - case SerialUsesRowID: - return "rowid" - case SerialUsesVirtualSequences: - return "virtual_sequence" - case SerialUsesSQLSequences: - return "sql_sequence" - default: - return fmt.Sprintf("invalid (%d)", m) - } -} - -// SerialNormalizationModeFromString converts a string into a SerialNormalizationMode -func SerialNormalizationModeFromString(val string) (_ SerialNormalizationMode, ok bool) { - switch strings.ToUpper(val) { - case "ROWID": - return SerialUsesRowID, true - case "VIRTUAL_SEQUENCE": - return SerialUsesVirtualSequences, true - case "SQL_SEQUENCE": - return SerialUsesSQLSequences, true - default: - return 0, false - } -} diff --git a/postgres/parser/syncutil/atomic.go b/postgres/parser/syncutil/atomic.go deleted file mode 100644 index 5f2bd3a686..0000000000 --- a/postgres/parser/syncutil/atomic.go +++ /dev/null @@ -1,91 +0,0 @@ -// Copyright 2023 Dolthub, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Copyright 2017 The Cockroach Authors. -// -// Use of this software is governed by the Business Source License -// included in the file licenses/BSL.txt. -// -// As of the Change Date specified in that file, in accordance with -// the Business Source License, use of this software will be governed -// by the Apache License, Version 2.0, included in the file -// licenses/APL.txt. - -package syncutil - -import ( - "math" - "sync/atomic" -) - -// AtomicFloat64 mimics the atomic types in the sync/atomic standard library, -// but for the float64 type. If you'd like to implement additional methods, -// consider checking out the expvar Float type for guidance: -// https://golang.org/src/expvar/expvar.go?s=2188:2222#L69 -type AtomicFloat64 uint64 - -// StoreFloat64 atomically stores a float64 value into the provided address. -func StoreFloat64(addr *AtomicFloat64, val float64) { - atomic.StoreUint64((*uint64)(addr), math.Float64bits(val)) -} - -// LoadFloat64 atomically loads tha float64 value from the provided address. -func LoadFloat64(addr *AtomicFloat64) (val float64) { - return math.Float64frombits(atomic.LoadUint64((*uint64)(addr))) -} - -// AtomicBool mimics an atomic boolean. -type AtomicBool uint32 - -// Set atomically sets the boolean. -func (b *AtomicBool) Set(v bool) { - s := uint32(0) - if v { - s = 1 - } - atomic.StoreUint32((*uint32)(b), s) -} - -// Get atomically gets the boolean. -func (b *AtomicBool) Get() bool { - return atomic.LoadUint32((*uint32)(b)) != 0 -} - -// Swap atomically swaps the value. -func (b *AtomicBool) Swap(v bool) bool { - wanted := uint32(0) - if v { - wanted = 1 - } - return atomic.SwapUint32((*uint32)(b), wanted) != 0 -} - -// AtomicString gives you atomic-style APIs for string. -type AtomicString struct { - s atomic.Value -} - -// Set atomically sets str as new value. -func (s *AtomicString) Set(val string) { - s.s.Store(val) -} - -// Get atomically returns the current value. -func (s *AtomicString) Get() string { - val := s.s.Load() - if val == nil { - return "" - } - return val.(string) -} diff --git a/postgres/parser/syncutil/int_map.go b/postgres/parser/syncutil/int_map.go deleted file mode 100644 index fcb2b6c2cc..0000000000 --- a/postgres/parser/syncutil/int_map.go +++ /dev/null @@ -1,413 +0,0 @@ -// Copyright 2023 Dolthub, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Copyright 2019 The Cockroach Authors. -// -// Use of this software is governed by the Business Source License -// included in the file licenses/BSL.txt. -// -// As of the Change Date specified in that file, in accordance with -// the Business Source License, use of this software will be governed -// by the Apache License, Version 2.0, included in the file -// licenses/APL.txt. - -// Copyright 2016 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in licenses/BSD-golang.txt. - -// This code originated in Go's sync package. - -package syncutil - -import ( - "sync/atomic" - "unsafe" -) - -// IntMap is a concurrent map with amortized-constant-time loads, stores, and -// deletes. It is safe for multiple goroutines to call a Map's methods -// concurrently. -// -// It is optimized for use in concurrent loops with keys that are -// stable over time, and either few steady-state stores, or stores -// localized to one goroutine per key. -// -// For use cases that do not share these attributes, it will likely have -// comparable or worse performance and worse type safety than an ordinary -// map paired with a read-write mutex. -// -// Nil values are not supported; to use an IntMap as a set store a -// dummy non-nil pointer instead of nil. -// -// The zero Map is valid and empty. -// -// A Map must not be copied after first use. -type IntMap struct { - mu Mutex - - // read contains the portion of the map's contents that are safe for - // concurrent access (with or without mu held). - // - // The read field itself is always safe to load, but must only be stored with - // mu held. - // - // Entries stored in read may be updated concurrently without mu, but updating - // a previously-expunged entry requires that the entry be copied to the dirty - // map and unexpunged with mu held. - read unsafe.Pointer // *readOnly - - // dirty contains the portion of the map's contents that require mu to be - // held. To ensure that the dirty map can be promoted to the read map quickly, - // it also includes all of the non-expunged entries in the read map. - // - // Expunged entries are not stored in the dirty map. An expunged entry in the - // clean map must be unexpunged and added to the dirty map before a new value - // can be stored to it. - // - // If the dirty map is nil, the next write to the map will initialize it by - // making a shallow copy of the clean map, omitting stale entries. - dirty map[int64]*entry - - // misses counts the number of loads since the read map was last updated that - // needed to lock mu to determine whether the key was present. - // - // Once enough misses have occurred to cover the cost of copying the dirty - // map, the dirty map will be promoted to the read map (in the unamended - // state) and the next store to the map will make a new dirty copy. - misses int -} - -// readOnly is an immutable struct stored atomically in the Map.read field. -type readOnly struct { - m map[int64]*entry - amended bool // true if the dirty map contains some key not in m. -} - -// expunged is an arbitrary pointer that marks entries which have been deleted -// from the dirty map. -var expunged = unsafe.Pointer(new(int)) - -// An entry is a slot in the map corresponding to a particular key. -type entry struct { - // p points to the value stored for the entry. - // - // If p == nil, the entry has been deleted and m.dirty == nil. - // - // If p == expunged, the entry has been deleted, m.dirty != nil, and the entry - // is missing from m.dirty. - // - // Otherwise, the entry is valid and recorded in m.read.m[key] and, if m.dirty - // != nil, in m.dirty[key]. - // - // An entry can be deleted by atomic replacement with nil: when m.dirty is - // next created, it will atomically replace nil with expunged and leave - // m.dirty[key] unset. - // - // An entry's associated value can be updated by atomic replacement, provided - // p != expunged. If p == expunged, an entry's associated value can be updated - // only after first setting m.dirty[key] = e so that lookups using the dirty - // map find the entry. - p unsafe.Pointer -} - -func newEntry(r unsafe.Pointer) *entry { - return &entry{p: r} -} - -// Load returns the value stored in the map for a key, or nil if no -// value is present. -// The ok result indicates whether value was found in the map. -func (m *IntMap) Load(key int64) (value unsafe.Pointer, ok bool) { - read := m.getRead() - e, ok := read.m[key] - if !ok && read.amended { - m.mu.Lock() - // Avoid reporting a spurious miss if m.dirty got promoted while we were - // blocked on m.mu. (If further loads of the same key will not miss, it's - // not worth copying the dirty map for this key.) - read = m.getRead() - e, ok = read.m[key] - if !ok && read.amended { - e, ok = m.dirty[key] - // Regardless of whether the entry was present, record a miss: this key - // will take the slow path until the dirty map is promoted to the read - // map. - m.missLocked() - } - m.mu.Unlock() - } - if !ok { - return nil, false - } - return e.load() -} - -func (e *entry) load() (value unsafe.Pointer, ok bool) { - p := atomic.LoadPointer(&e.p) - if p == nil || p == expunged { - return nil, false - } - return p, true -} - -// Store sets the value for a key. -func (m *IntMap) Store(key int64, value unsafe.Pointer) { - read := m.getRead() - if e, ok := read.m[key]; ok && e.tryStore(value) { - return - } - - m.mu.Lock() - read = m.getRead() - if e, ok := read.m[key]; ok { - if e.unexpungeLocked() { - // The entry was previously expunged, which implies that there is a - // non-nil dirty map and this entry is not in it. - m.dirty[key] = e - } - e.storeLocked(value) - } else if e, ok := m.dirty[key]; ok { - e.storeLocked(value) - } else { - if !read.amended { - // We're adding the first new key to the dirty map. - // Make sure it is allocated and mark the read-only map as incomplete. - m.dirtyLocked() - atomic.StorePointer(&m.read, unsafe.Pointer(&readOnly{m: read.m, amended: true})) - } - m.dirty[key] = newEntry(value) - } - m.mu.Unlock() -} - -// tryStore stores a value if the entry has not been expunged. -// -// If the entry is expunged, tryStore returns false and leaves the entry -// unchanged. -func (e *entry) tryStore(r unsafe.Pointer) bool { - p := atomic.LoadPointer(&e.p) - if p == expunged { - return false - } - for { - if atomic.CompareAndSwapPointer(&e.p, p, r) { - return true - } - p = atomic.LoadPointer(&e.p) - if p == expunged { - return false - } - } -} - -// unexpungeLocked ensures that the entry is not marked as expunged. -// -// If the entry was previously expunged, it must be added to the dirty map -// before m.mu is unlocked. -func (e *entry) unexpungeLocked() (wasExpunged bool) { - return atomic.CompareAndSwapPointer(&e.p, expunged, nil) -} - -// storeLocked unconditionally stores a value to the entry. -// -// The entry must be known not to be expunged. -func (e *entry) storeLocked(r unsafe.Pointer) { - atomic.StorePointer(&e.p, r) -} - -// LoadOrStore returns the existing value for the key if present. -// Otherwise, it stores and returns the given value. -// The loaded result is true if the value was loaded, false if stored. -func (m *IntMap) LoadOrStore(key int64, value unsafe.Pointer) (actual unsafe.Pointer, loaded bool) { - // Avoid locking if it's a clean hit. - read := m.getRead() - if e, ok := read.m[key]; ok { - actual, loaded, ok = e.tryLoadOrStore(value) - if ok { - return actual, loaded - } - } - - m.mu.Lock() - read = m.getRead() - if e, ok := read.m[key]; ok { - if e.unexpungeLocked() { - m.dirty[key] = e - } - actual, loaded, _ = e.tryLoadOrStore(value) - } else if e, ok := m.dirty[key]; ok { - actual, loaded, _ = e.tryLoadOrStore(value) - m.missLocked() - } else { - if !read.amended { - // We're adding the first new key to the dirty map. - // Make sure it is allocated and mark the read-only map as incomplete. - m.dirtyLocked() - atomic.StorePointer(&m.read, unsafe.Pointer(&readOnly{m: read.m, amended: true})) - } - m.dirty[key] = newEntry(value) - actual, loaded = value, false - } - m.mu.Unlock() - - return actual, loaded -} - -// tryLoadOrStore atomically loads or stores a value if the entry is not -// expunged. -// -// If the entry is expunged, tryLoadOrStore leaves the entry unchanged and -// returns with ok==false. -func (e *entry) tryLoadOrStore(r unsafe.Pointer) (actual unsafe.Pointer, loaded, ok bool) { - p := atomic.LoadPointer(&e.p) - if p == expunged { - return nil, false, false - } - if p != nil { - return p, true, true - } - - for { - if atomic.CompareAndSwapPointer(&e.p, nil, r) { - return r, false, true - } - p = atomic.LoadPointer(&e.p) - if p == expunged { - return nil, false, false - } - if p != nil { - return p, true, true - } - } -} - -// Delete deletes the value for a key. -func (m *IntMap) Delete(key int64) { - read := m.getRead() - e, ok := read.m[key] - if !ok && read.amended { - m.mu.Lock() - read = m.getRead() - e, ok = read.m[key] - if !ok && read.amended { - delete(m.dirty, key) - } - m.mu.Unlock() - } - if ok { - e.delete() - } -} - -func (e *entry) delete() (hadValue bool) { - for { - p := atomic.LoadPointer(&e.p) - if p == nil || p == expunged { - return false - } - if atomic.CompareAndSwapPointer(&e.p, p, nil) { - return true - } - } -} - -// Range calls f sequentially for each key and value present in the map. -// If f returns false, range stops the iteration. -// -// Range does not necessarily correspond to any consistent snapshot of the Map's -// contents: no key will be visited more than once, but if the value for any key -// is stored or deleted concurrently, Range may reflect any mapping for that key -// from any point during the Range call. -// -// Range may be O(N) with the number of elements in the map even if f returns -// false after a constant number of calls. -func (m *IntMap) Range(f func(key int64, value unsafe.Pointer) bool) { - // We need to be able to iterate over all of the keys that were already - // present at the start of the call to Range. - // If read.amended is false, then read.m satisfies that property without - // requiring us to hold m.mu for a long time. - read := m.getRead() - if read.amended { - // m.dirty contains keys not in read.m. Fortunately, Range is already O(N) - // (assuming the caller does not break out early), so a call to Range - // amortizes an entire copy of the map: we can promote the dirty copy - // immediately! - m.mu.Lock() - read = m.getRead() - if read.amended { - // Don't let read escape directly, otherwise it will allocate even - // when read.amended is false. Instead, constrain the allocation to - // just this branch. - newRead := &readOnly{m: m.dirty} - atomic.StorePointer(&m.read, unsafe.Pointer(newRead)) - read = *newRead - m.dirty = nil - m.misses = 0 - } - m.mu.Unlock() - } - - for k, e := range read.m { - v, ok := e.load() - if !ok { - continue - } - if !f(k, v) { - break - } - } -} - -func (m *IntMap) missLocked() { - m.misses++ - if m.misses < len(m.dirty) { - return - } - atomic.StorePointer(&m.read, unsafe.Pointer(&readOnly{m: m.dirty})) - m.dirty = nil - m.misses = 0 -} - -func (m *IntMap) dirtyLocked() { - if m.dirty != nil { - return - } - - read := m.getRead() - m.dirty = make(map[int64]*entry, len(read.m)) - for k, e := range read.m { - if !e.tryExpungeLocked() { - m.dirty[k] = e - } - } -} - -func (m *IntMap) getRead() readOnly { - read := (*readOnly)(atomic.LoadPointer(&m.read)) - if read == nil { - return readOnly{} - } - return *read -} - -func (e *entry) tryExpungeLocked() (isExpunged bool) { - p := atomic.LoadPointer(&e.p) - for p == nil { - if atomic.CompareAndSwapPointer(&e.p, nil, expunged) { - return true - } - p = atomic.LoadPointer(&e.p) - } - return p == expunged -} diff --git a/postgres/parser/syncutil/mutex_deadlock.go b/postgres/parser/syncutil/mutex_deadlock.go deleted file mode 100644 index 93af5d6a63..0000000000 --- a/postgres/parser/syncutil/mutex_deadlock.go +++ /dev/null @@ -1,60 +0,0 @@ -// Copyright 2023 Dolthub, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Copyright 2016 The Cockroach Authors. -// -// Use of this software is governed by the Business Source License -// included in the file licenses/BSL.txt. -// -// As of the Change Date specified in that file, in accordance with -// the Business Source License, use of this software will be governed -// by the Apache License, Version 2.0, included in the file -// licenses/APL.txt. - -//go:build deadlock -// +build deadlock - -package syncutil - -import ( - "time" - - deadlock "github.com/sasha-s/go-deadlock" -) - -func init() { - deadlock.Opts.DeadlockTimeout = 5 * time.Minute -} - -// A Mutex is a mutual exclusion lock. -type Mutex struct { - deadlock.Mutex -} - -// AssertHeld is a no-op for deadlock mutexes. -func (m *Mutex) AssertHeld() { -} - -// An RWMutex is a reader/writer mutual exclusion lock. -type RWMutex struct { - deadlock.RWMutex -} - -// AssertHeld is a no-op for deadlock mutexes. -func (rw *RWMutex) AssertHeld() { -} - -// AssertRHeld is a no-op for deadlock mutexes. -func (rw *RWMutex) AssertRHeld() { -} diff --git a/postgres/parser/syncutil/mutex_sync.go b/postgres/parser/syncutil/mutex_sync.go deleted file mode 100644 index dbc2d8f617..0000000000 --- a/postgres/parser/syncutil/mutex_sync.go +++ /dev/null @@ -1,76 +0,0 @@ -// Copyright 2023 Dolthub, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Copyright 2016 The Cockroach Authors. -// -// Use of this software is governed by the Business Source License -// included in the file licenses/BSL.txt. -// -// As of the Change Date specified in that file, in accordance with -// the Business Source License, use of this software will be governed -// by the Apache License, Version 2.0, included in the file -// licenses/APL.txt. - -//go:build !deadlock && !race -// +build !deadlock,!race - -package syncutil - -import "sync" - -// A Mutex is a mutual exclusion lock. -type Mutex struct { - sync.Mutex -} - -// AssertHeld may panic if the mutex is not locked (but it is not required to -// do so). Functions which require that their callers hold a particular lock -// may use this to enforce this requirement more directly than relying on the -// race detector. -// -// Note that we do not require the lock to be held by any particular thread, -// just that some thread holds the lock. This is both more efficient and allows -// for rare cases where a mutex is locked in one thread and used in another. -func (m *Mutex) AssertHeld() { -} - -// An RWMutex is a reader/writer mutual exclusion lock. -type RWMutex struct { - sync.RWMutex -} - -// AssertHeld may panic if the mutex is not locked for writing (but it is not -// required to do so). Functions which require that their callers hold a -// particular lock may use this to enforce this requirement more directly than -// relying on the race detector. -// -// Note that we do not require the exclusive lock to be held by any particular -// thread, just that some thread holds the lock. This is both more efficient -// and allows for rare cases where a mutex is locked in one thread and used in -// another. -func (rw *RWMutex) AssertHeld() { -} - -// AssertRHeld may panic if the mutex is not locked for reading (but it is not -// required to do so). If the mutex is locked for writing, it is also considered -// to be locked for reading. Functions which require that their callers hold a -// particular lock may use this to enforce this requirement more directly than -// relying on the race detector. -// -// Note that we do not require the shared lock to be held by any particular -// thread, just that some thread holds the lock. This is both more efficient -// and allows for rare cases where a mutex is locked in one thread and used in -// another. -func (rw *RWMutex) AssertRHeld() { -} diff --git a/postgres/parser/syncutil/mutex_sync_race.go b/postgres/parser/syncutil/mutex_sync_race.go deleted file mode 100644 index be4176d17c..0000000000 --- a/postgres/parser/syncutil/mutex_sync_race.go +++ /dev/null @@ -1,138 +0,0 @@ -// Copyright 2023 Dolthub, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Copyright 2016 The Cockroach Authors. -// -// Use of this software is governed by the Business Source License -// included in the file licenses/BSL.txt. -// -// As of the Change Date specified in that file, in accordance with -// the Business Source License, use of this software will be governed -// by the Apache License, Version 2.0, included in the file -// licenses/APL.txt. - -//go:build !deadlock && race -// +build !deadlock,race - -package syncutil - -import ( - "sync" - "sync/atomic" -) - -// A Mutex is a mutual exclusion lock. -type Mutex struct { - mu sync.Mutex - wLocked int32 // updated atomically -} - -// Lock locks m. -func (m *Mutex) Lock() { - m.mu.Lock() - atomic.StoreInt32(&m.wLocked, 1) -} - -// Unlock unlocks m. -func (m *Mutex) Unlock() { - atomic.StoreInt32(&m.wLocked, 0) - m.mu.Unlock() -} - -// AssertHeld may panic if the mutex is not locked (but it is not required to -// do so). Functions which require that their callers hold a particular lock -// may use this to enforce this requirement more directly than relying on the -// race detector. -// -// Note that we do not require the lock to be held by any particular thread, -// just that some thread holds the lock. This is both more efficient and allows -// for rare cases where a mutex is locked in one thread and used in another. -func (m *Mutex) AssertHeld() { - if atomic.LoadInt32(&m.wLocked) == 0 { - panic("mutex is not write locked") - } -} - -// An RWMutex is a reader/writer mutual exclusion lock. -type RWMutex struct { - sync.RWMutex - wLocked int32 // updated atomically - rLocked int32 // updated atomically -} - -// Lock locks rw for writing. -func (rw *RWMutex) Lock() { - rw.RWMutex.Lock() - atomic.StoreInt32(&rw.wLocked, 1) -} - -// Unlock unlocks rw for writing. -func (rw *RWMutex) Unlock() { - atomic.StoreInt32(&rw.wLocked, 0) - rw.RWMutex.Unlock() -} - -// RLock locks m for reading. -func (rw *RWMutex) RLock() { - rw.RWMutex.RLock() - atomic.AddInt32(&rw.rLocked, 1) -} - -// RUnlock undoes a single RLock call. -func (rw *RWMutex) RUnlock() { - atomic.AddInt32(&rw.rLocked, -1) - rw.RWMutex.RUnlock() -} - -// RLocker returns a Locker interface that implements -// the Lock and Unlock methods by calling rw.RLock and rw.RUnlock. -func (rw *RWMutex) RLocker() sync.Locker { - return (*rlocker)(rw) -} - -type rlocker RWMutex - -func (r *rlocker) Lock() { (*RWMutex)(r).RLock() } -func (r *rlocker) Unlock() { (*RWMutex)(r).RUnlock() } - -// AssertHeld may panic if the mutex is not locked for writing (but it is not -// required to do so). Functions which require that their callers hold a -// particular lock may use this to enforce this requirement more directly than -// relying on the race detector. -// -// Note that we do not require the exclusive lock to be held by any particular -// thread, just that some thread holds the lock. This is both more efficient -// and allows for rare cases where a mutex is locked in one thread and used in -// another. -func (rw *RWMutex) AssertHeld() { - if atomic.LoadInt32(&rw.wLocked) == 0 { - panic("mutex is not write locked") - } -} - -// AssertRHeld may panic if the mutex is not locked for reading (but it is not -// required to do so). If the mutex is locked for writing, it is also considered -// to be locked for reading. Functions which require that their callers hold a -// particular lock may use this to enforce this requirement more directly than -// relying on the race detector. -// -// Note that we do not require the shared lock to be held by any particular -// thread, just that some thread holds the lock. This is both more efficient -// and allows for rare cases where a mutex is locked in one thread and used in -// another. -func (rw *RWMutex) AssertRHeld() { - if atomic.LoadInt32(&rw.wLocked) == 0 && atomic.LoadInt32(&rw.rLocked) == 0 { - panic("mutex is not read locked") - } -} diff --git a/postgres/parser/syncutil/singleflight/singleflight.go b/postgres/parser/syncutil/singleflight/singleflight.go deleted file mode 100644 index 4094def34e..0000000000 --- a/postgres/parser/syncutil/singleflight/singleflight.go +++ /dev/null @@ -1,172 +0,0 @@ -// Copyright 2023 Dolthub, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Copyright 2019 The Cockroach Authors. -// -// Use of this software is governed by the Business Source License -// included in the file licenses/BSL.txt. -// -// As of the Change Date specified in that file, in accordance with -// the Business Source License, use of this software will be governed -// by the Apache License, Version 2.0, included in the file -// licenses/APL.txt. - -// Copyright 2013 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in licenses/BSD-golang.txt. - -// This code originated in Go's internal/singleflight package. - -// Package singleflight provides a duplicate function call suppression -// mechanism. -package singleflight - -import ( - "sync" - - "github.com/dolthub/doltgresql/postgres/parser/syncutil" -) - -// call is an in-flight or completed singleflight.Do call -type call struct { - wg sync.WaitGroup - - // These fields are written once before the WaitGroup is done - // and are only read after the WaitGroup is done. - val interface{} - err error - - // These fields are read and written with the singleflight - // mutex held before the WaitGroup is done, and are read but - // not written after the WaitGroup is done. - dups int - chans []chan<- Result -} - -// Group represents a class of work and forms a namespace in -// which units of work can be executed with duplicate suppression. -type Group struct { - mu syncutil.Mutex // protects m - m map[string]*call // lazily initialized -} - -// Result holds the results of Do, so they can be passed -// on a channel. -type Result struct { - Val interface{} - Err error - Shared bool -} - -// Do executes and returns the results of the given function, making -// sure that only one execution is in-flight for a given key at a -// time. If a duplicate comes in, the duplicate caller waits for the -// original to complete and receives the same results. -// The return value shared indicates whether v was given to multiple callers. -func (g *Group) Do( - key string, fn func() (interface{}, error), -) (v interface{}, shared bool, err error) { - g.mu.Lock() - if g.m == nil { - g.m = make(map[string]*call) - } - if c, ok := g.m[key]; ok { - c.dups++ - g.mu.Unlock() - c.wg.Wait() - return c.val, true, c.err - } - c := new(call) - c.wg.Add(1) - g.m[key] = c - g.mu.Unlock() - - g.doCall(c, key, fn) - return c.val, c.dups > 0, c.err -} - -// DoChan is like Do but returns a channel that will receive the results when -// they are ready. The method also returns a boolean specifying whether the -// caller's fn function will be called or not. This return value lets callers -// identify a unique "leader" for a flight. -// -// NOTE: DoChan makes it possible to initiate or join a flight while holding a -// lock without holding it for the duration of the flight. A common usage -// pattern is: -// 1. Check some datastructure to see if it contains the value you're looking -// for. -// 2. If it doesn't, initiate or join a flight to produce it. -// -// Step one is expected to be done while holding a lock. Modifying the -// datastructure in the callback is expected to need to take the same lock. Once -// a caller proceeds to step two, it likely wants to keep the lock until -// DoChan() returned a channel, in order to ensure that a flight is only started -// before any modifications to the datastructure occurred (relative to the state -// observed in step one). Were the lock to be released before calling DoChan(), -// a previous flight might modify the datastructure before our flight began. -func (g *Group) DoChan(key string, fn func() (interface{}, error)) (<-chan Result, bool) { - ch := make(chan Result, 1) - g.mu.Lock() - if g.m == nil { - g.m = make(map[string]*call) - } - if c, ok := g.m[key]; ok { - c.dups++ - c.chans = append(c.chans, ch) - g.mu.Unlock() - return ch, false - } - c := &call{chans: []chan<- Result{ch}} - c.wg.Add(1) - g.m[key] = c - g.mu.Unlock() - - go g.doCall(c, key, fn) - - return ch, true -} - -// doCall handles the single call for a key. -func (g *Group) doCall(c *call, key string, fn func() (interface{}, error)) { - c.val, c.err = fn() - c.wg.Done() - - g.mu.Lock() - delete(g.m, key) - for _, ch := range c.chans { - ch <- Result{c.val, c.err, c.dups > 0} - } - g.mu.Unlock() -} - -var _ = (*Group).Forget - -// Forget tells the singleflight to forget about a key. Future calls -// to Do for this key will call the function rather than waiting for -// an earlier call to complete. -func (g *Group) Forget(key string) { - g.mu.Lock() - delete(g.m, key) - g.mu.Unlock() -} - -// NumCalls returns the number of in-flight calls for a given key. -func (g *Group) NumCalls(key string) int { - g.mu.Lock() - defer g.mu.Unlock() - if c, ok := g.m[key]; ok { - return c.dups + 1 - } - return 0 -} diff --git a/postgres/parser/unique/unique.go b/postgres/parser/unique/unique.go deleted file mode 100644 index cfe7ba7a5e..0000000000 --- a/postgres/parser/unique/unique.go +++ /dev/null @@ -1,124 +0,0 @@ -// Copyright 2023 Dolthub, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Copyright 2020 The Cockroach Authors. -// -// Use of this software is governed by the Business Source License -// included in the file licenses/BSL.txt. -// -// As of the Change Date specified in that file, in accordance with -// the Business Source License, use of this software will be governed -// by the Apache License, Version 2.0, included in the file -// licenses/APL.txt. - -package unique - -import ( - "bytes" - "reflect" - "sort" -) - -// UniquifyByteSlices takes as input a slice of slices of bytes, and -// deduplicates them using a sort and unique. The output will not contain any -// duplicates but it will be sorted. -func UniquifyByteSlices(slices [][]byte) [][]byte { - if len(slices) == 0 { - return slices - } - // First sort: - sort.Slice(slices, func(i int, j int) bool { - return bytes.Compare(slices[i], slices[j]) < 0 - }) - // Then distinct: (wouldn't it be nice if Go had generics?) - lastUniqueIdx := 0 - for i := 1; i < len(slices); i++ { - if !bytes.Equal(slices[i], slices[lastUniqueIdx]) { - // We found a unique entry, at index i. The last unique entry in the array - // was at lastUniqueIdx, so set the entry after that one to our new unique - // entry, and bump lastUniqueIdx for the next loop iteration. - lastUniqueIdx++ - slices[lastUniqueIdx] = slices[i] - } - } - slices = slices[:lastUniqueIdx+1] - return slices -} - -// UniquifyAcrossSlices removes elements from both slices that are duplicated -// across both of the slices. For example, inputs [1,2,3], [2,3,4] would remove -// 2 and 3 from both lists. -// It assumes that both slices are pre-sorted using the same comparison metric -// as cmpFunc provides, and also already free of duplicates internally. It -// returns the slices, which will have also been sorted as a side effect. -// cmpFunc compares the lth index of left to the rth index of right. It must -// return less than 0 if the left element is less than the right element, 0 if -// equal, and greater than 0 otherwise. -// setLeft sets the ith index of left to the jth index of left. -// setRight sets the ith index of right to the jth index of right. -// The function returns the new lengths of both input slices, whose elements -// will have been mutated, but whose lengths must be set the new lengths by -// the caller. -func UniquifyAcrossSlices( - left interface{}, - right interface{}, - cmpFunc func(l, r int) int, - setLeft func(i, j int), - setRight func(i, j int), -) (leftLen, rightLen int) { - leftSlice := reflect.ValueOf(left) - rightSlice := reflect.ValueOf(right) - - lLen := leftSlice.Len() - rLen := rightSlice.Len() - - var lIn, lOut int - var rIn, rOut int - - // Remove entries that are duplicated across both entry lists. - // This loop walks through both lists using a merge strategy. Two pointers per - // list are maintained. One is the "input pointer", which is always the ith - // element of the input list. One is the "output pointer", which is the index - // after the most recent unique element in the list. Every time we bump the - // input pointer, we also set the element at the output pointer to that at - // the input pointer, so we don't have to use extra space - we're - // deduplicating in-place. - for rIn < rLen || lIn < lLen { - var cmp int - if lIn == lLen { - cmp = 1 - } else if rIn == rLen { - cmp = -1 - } else { - cmp = cmpFunc(lIn, rIn) - } - if cmp < 0 { - setLeft(lOut, lIn) - lIn++ - lOut++ - } else if cmp > 0 { - setRight(rOut, rIn) - rIn++ - rOut++ - } else { - // Elements are identical - we want to remove them from the list. So - // we increment our input indices without touching our output indices. - // Next time through the loop, we'll shift the next element back to - // the last output index which is now lagging behind the input index. - lIn++ - rIn++ - } - } - return lOut, rOut -} diff --git a/postgres/parser/uuid/generator.go b/postgres/parser/uuid/generator.go index 8a66557df9..c6a8abd881 100644 --- a/postgres/parser/uuid/generator.go +++ b/postgres/parser/uuid/generator.go @@ -43,8 +43,6 @@ import ( "time" "github.com/cockroachdb/errors" - - "github.com/dolthub/doltgresql/postgres/parser/syncutil" ) // Difference in 100-nanosecond intervals between @@ -102,7 +100,7 @@ type Generator interface { type Gen struct { clockSequenceOnce sync.Once hardwareAddrOnce sync.Once - storageMutex syncutil.Mutex + storageMutex sync.Mutex rand io.Reader