Skip to content

Commit

Permalink
Add huggingface extension
Browse files Browse the repository at this point in the history
  • Loading branch information
matthewmturner committed Jan 15, 2025
1 parent 53c70ed commit 68dc9ff
Show file tree
Hide file tree
Showing 5 changed files with 165 additions and 1 deletion.
48 changes: 47 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ metrics = {version = "0.24.0", optional = true }
metrics-exporter-prometheus = {version = "0.16.0", optional = true }
num_cpus = "1.16.0"
object_store = { version = "0.11.0", features = ["aws"], optional = true }
opendal = { version = "0.51", features = ["services-huggingface"], optional = true }
object_store_opendal = { version = "0.49", optional = true}
parking_lot = "0.12.3"
parquet = "53.0.0"
pin-project-lite = {version = "0.2.14" }
Expand Down Expand Up @@ -76,6 +78,7 @@ s3 = ["object_store/aws", "url"]
functions-json = ["dep:datafusion-functions-json"]
functions-parquet = ["dep:datafusion-functions-parquet"]
metrics = ["dep:metrics", "dep:metrics-exporter-prometheus"]
huggingface = ["opendal", "object_store_opendal", "url"]

[[bin]]
name = "dft"
Expand Down
12 changes: 12 additions & 0 deletions src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,10 +153,22 @@ impl S3Config {
}
}

#[cfg(feature = "huggingface")]
#[derive(Clone, Debug, Deserialize)]
pub struct HuggingFaceConfig {
pub repo_type: Option<String>,
pub repo_id: Option<String>,
pub revision: Option<String>,
pub root: Option<String>,
pub token: Option<String>,
}

#[derive(Clone, Debug, Deserialize)]
pub struct ObjectStoreConfig {
#[cfg(feature = "s3")]
pub s3: Option<Vec<S3Config>>,
#[cfg(feature = "huggingface")]
pub huggingface: Option<Vec<HuggingFaceConfig>>,
}

#[derive(Clone, Debug, Deserialize)]
Expand Down
99 changes: 99 additions & 0 deletions src/extensions/huggingface.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! Huggingface Integration: [HuggingFaceExtension]
use crate::config::ExecutionConfig;
use crate::extensions::{DftSessionStateBuilder, Extension};
use log::info;
use std::sync::Arc;

use opendal::{services::Huggingface, Builder, Operator};
use url::Url;

#[derive(Debug, Default)]
pub struct HuggingFaceExtension {}

impl HuggingFaceExtension {
pub fn new() -> Self {
info!("HuggingFace!");
Self {}
}
}

#[async_trait::async_trait]
impl Extension for HuggingFaceExtension {
async fn register(
&self,
config: ExecutionConfig,
builder: &mut DftSessionStateBuilder,
) -> datafusion_common::Result<()> {
let Some(object_store_config) = &config.object_store else {
return Ok(());
};

let Some(huggingface_configs) = &object_store_config.huggingface else {
return Ok(());
};

info!("Huggingface configs exists");
for huggingface_config in huggingface_configs {
// I'm not that famliar with Huggingface so I'm not sure what permutations of config
// values are supposed to work.

let mut base_url = String::from("https://huggingface.co/");
let mut url_parts = vec!["https://huggingface.co"];
let mut hf_builder = Huggingface::default();
if let Some(repo_type) = &huggingface_config.repo_type {
hf_builder = hf_builder.repo_type(repo_type);
url_parts.push(repo_type)
};
if let Some(repo_id) = &huggingface_config.repo_id {
hf_builder = hf_builder.repo_id(repo_id);
url_parts.push(repo_id);
};
if let Some(revision) = &huggingface_config.revision {
hf_builder = hf_builder.revision(revision);
url_parts.push("tree");
url_parts.push(revision);
};
if let Some(root) = &huggingface_config.root {
hf_builder = hf_builder.root(root);
};
if let Some(token) = &huggingface_config.token {
hf_builder = hf_builder.repo_id(token);
};

let operator = Operator::new(hf_builder)
.map_err(|e| {
datafusion_common::error::DataFusionError::External(e.to_string().into())
})?
.finish();

let store = object_store_opendal::OpendalStore::new(operator);
let url = Url::parse(url_parts.join("/").as_str()).map_err(|e| {
datafusion_common::error::DataFusionError::External(e.to_string().into())
})?;
info!("Registering store for huggingface url: {url}");
builder
.runtime_env()
.register_object_store(&url, Arc::new(store));
}

Ok(())
}
}
4 changes: 4 additions & 0 deletions src/extensions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ mod builder;
mod deltalake;
#[cfg(feature = "hudi")]
mod hudi;
#[cfg(feature = "huggingface")]
mod huggingface;
#[cfg(feature = "iceberg")]
mod iceberg;
#[cfg(feature = "s3")]
Expand Down Expand Up @@ -62,5 +64,7 @@ pub fn enabled_extensions() -> Vec<Arc<dyn Extension>> {
Arc::new(hudi::HudiExtension::new()),
#[cfg(feature = "iceberg")]
Arc::new(iceberg::IcebergExtension::new()),
#[cfg(feature = "huggingface")]
Arc::new(huggingface::HuggingFaceExtension::new()),
]
}

0 comments on commit 68dc9ff

Please sign in to comment.