From 1984082ccf35e19f9da1e06626a4e5029a5c1854 Mon Sep 17 00:00:00 2001 From: ptaylor Date: Thu, 15 Sep 2022 22:31:10 -0700 Subject: [PATCH] fix compile error --- modules/sql/src/cluster/local.ts | 6 ++-- modules/sql/src/context.ts | 2 +- modules/sql/src/graph.ts | 51 +++++++++++++++++++-------- modules/sql/test/sql-context-tests.ts | 14 ++++++++ 4 files changed, 54 insertions(+), 19 deletions(-) diff --git a/modules/sql/src/cluster/local.ts b/modules/sql/src/cluster/local.ts index 4645c173a..578de7f45 100644 --- a/modules/sql/src/cluster/local.ts +++ b/modules/sql/src/cluster/local.ts @@ -1,4 +1,4 @@ -// Copyright (c) 2021, NVIDIA CORPORATION. +// Copyright (c) 2021-2022, NVIDIA CORPORATION. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -53,7 +53,5 @@ export class LocalSQLWorker implements Worker { public async dropTable(name: string) { await this.context.dropTable(name); } - public async sql(query: string, token: number) { - return await (await this.context.sql(query, token)).result(); - } + public async sql(query: string, token: number) { return await this.context.sql(query, token); } } diff --git a/modules/sql/src/context.ts b/modules/sql/src/context.ts index 44694db1e..57c68a042 100644 --- a/modules/sql/src/context.ts +++ b/modules/sql/src/context.ts @@ -263,7 +263,7 @@ export class SQLContext { * const sqlContext = new SQLContext(); * sqlContext.createTable('test_table', df); * - * sqlContext.sql('SELECT a FROM test_table').result(); // [1, 2, 3] + * await sqlContext.sql('SELECT a FROM test_table'); // [1, 2, 3] * ``` */ public sql(query: string, ctxToken: number = Math.random() * Number.MAX_SAFE_INTEGER | 0) { diff --git a/modules/sql/src/graph.ts b/modules/sql/src/graph.ts index e08569ad5..7e602d904 100644 --- a/modules/sql/src/graph.ts +++ b/modules/sql/src/graph.ts @@ -18,27 +18,50 @@ import {DataFrame, Table} from '@rapidsai/cudf'; let nonce = Math.random() * 1e3 | 0; -export class ExecutionGraph { +export class ExecutionGraph implements Promise { constructor(private _graph?: import('./rapidsai_sql').ExecutionGraph) {} - start(): void { this._graph?.start(); } + get[Symbol.toStringTag]() { return 'ExecutionGraph'; } - then() { return this.result(); } + then( + onfulfilled?: ((value: DataFrame[]) => TResult1 | PromiseLike)|undefined|null, + onrejected?: ((reason: any) => TResult2 | PromiseLike)|undefined| + null): Promise { + return this.result().then(onfulfilled, onrejected); + } - async result() { - const {names, tables} = - this._graph ? (await this._graph.result()) : {names: [], tables: [new Table({})]}; - const results: DataFrame[] = []; - tables.forEach((table: Table) => { - results.push(new DataFrame( - names.reduce((cols, name, i) => ({...cols, [name]: table.getColumnByIndex(i)}), {}))); - }); + catch(onrejected?: ((reason: any) => TResult | PromiseLike)|undefined| + null): Promise { + return this.result().catch(onrejected); + } + + finally(onfinally?: (() => void)|undefined|null): Promise { + return this.result().finally(onfinally); + } + + private _result: Promise|undefined; + + start() { this._graph?.start(); } + + result() { + if (!this._result) { + this._result = (async () => { + const {names, tables} = + this._graph ? (await this._graph.result()) : {names: [], tables: [new Table({})]}; + const results: DataFrame[] = []; + tables.forEach((table: Table) => { + results.push(new DataFrame( + names.reduce((cols, name, i) => ({...cols, [name]: table.getColumnByIndex(i)}), {}))); + }); - return results; + return results; + })(); + } + return this._result; } - async sendTo(id: number) { - return await this.result().then((dfs) => { + sendTo(id: number) { + return this.then((dfs) => { const {_graph} = this; const inFlightTables: Record = {}; if (_graph) { diff --git a/modules/sql/test/sql-context-tests.ts b/modules/sql/test/sql-context-tests.ts index 93ac9f362..88020f959 100644 --- a/modules/sql/test/sql-context-tests.ts +++ b/modules/sql/test/sql-context-tests.ts @@ -1,3 +1,17 @@ +// Copyright (c) 2021-2022, NVIDIA CORPORATION. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + import {DataFrame, Float64, Series, Utf8String} from '@rapidsai/cudf'; import {SQLContext} from '@rapidsai/sql';