Add a MultiModal example Task #292
Comments
I also got sent some code for the preprocessing (not sure what we need from there but I am putting it here in case it is useful).
preprocess-1.R - R
library(dplyr)
library(mgcv)
df <- read.csv2("data/ISIC_2020_Training_GroundTruth_v2.csv", sep=",")
### remove empty
keep_mask <- (df$sex != "") & (df$anatom_site_general_challenge != "") & !is.na(df$age_approx)
cat("removing", sum(!keep_mask), "rows with empty columns\n")
df <- df[keep_mask,]
### encode
df$sex <- factor(df$sex)
df$diagnosis <- factor(df$diagnosis)
df$site <- factor(df$anatom_site_general_challenge)
df$benign_malignant <- factor(df$benign_malignant)
df$patient_id <- factor(df$patient_id)
pats <- df %>% group_by(patient_id) %>% summarise(n=n()) %>% filter(n>=4)
df <- df[df$patient_id %in% pats$patient_id,]
cat("kept", nrow(df), "lesions from patients with at least four\n")
pats <- pats[sample(1:nrow(pats)),]
test_pats <- pats[1:170, "patient_id"]
tune_pats <- pats[171:340, "patient_id"]
train_pats <- pats[341:nrow(pats), "patient_id"]
test_df <- df %>% filter(patient_id %in% test_pats$patient_id)
tune_df <- df %>% filter(patient_id %in% tune_pats$patient_id)
train_df <- df %>% filter(patient_id %in% train_pats$patient_id)
cat("got", nrow(test_df), "lesions for test\n")
cat("got", nrow(tune_df), "lesions for tuning\n")
cat("got", nrow(train_df), "lesions for training\n")
test_df$subset <- "test"
tune_df$subset <- "tune"
train_df$subset <- "trainval"
df_all <- rbind(train_df, tune_df, test_df)
saveRDS(df_all, "data/train-processed.RDS")
# model matrix for structured effects
#mdl <- gam(target ~ site + sex + s(age_approx), family = "binomial", data = df)
mdl <- bam(
  target ~ site + sex + s(age_approx, by=sex),
  family = "binomial", data = df_all, discrete = TRUE, nthreads = 4
)
x_struc <- as.data.frame(model.matrix(mdl))
x_struc$target <- df_all$target
x_struc$image <- df_all$image
x_struc$patient_id <- df_all$patient_id
x_struc$subset <- df_all$subset
write.csv2(x_struc, "data/x_struc.csv")
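The split in the R script above is made per patient rather than per lesion, so all lesions of one patient land in the same subset. A minimal Python sketch of that idea on toy data follows. The function name `split_by_patient`, the seed, and the toy rows are illustrative assumptions and not part of the original script.

```python
import random
from collections import defaultdict


def split_by_patient(rows, n_test, n_tune, min_lesions=4, seed=0):
    """Assign each row to a subset based on its patient, never per row."""
    by_patient = defaultdict(list)
    for row in rows:
        by_patient[row["patient_id"]].append(row)

    # Keep only patients with at least `min_lesions` rows, as in the R script.
    patients = sorted(p for p, rs in by_patient.items() if len(rs) >= min_lesions)
    random.Random(seed).shuffle(patients)

    # The first n_test patients go to test, the next n_tune to tune, the rest to trainval.
    subset_of = {}
    for i, patient in enumerate(patients):
        if i < n_test:
            subset_of[patient] = "test"
        elif i < n_test + n_tune:
            subset_of[patient] = "tune"
        else:
            subset_of[patient] = "trainval"

    # Rows of patients that were filtered out are dropped.
    return [
        {**row, "subset": subset_of[row["patient_id"]]}
        for row in rows
        if row["patient_id"] in subset_of
    ]


# Toy data: six hypothetical patients with four lesions each, plus one with two.
rows = [{"patient_id": f"P{p}", "image": f"img_{p}_{k}"} for p in range(6) for k in range(4)]
rows += [{"patient_id": "P_small", "image": f"img_small_{k}"} for k in range(2)]

labelled = split_by_patient(rows, n_test=2, n_tune=2)
```

Because the subset is looked up by `patient_id`, no patient can appear in more than one subset, which is the property the R code gets from filtering on `patient_id %in% ...`.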
import torch
import os
from tqdm import tqdm
import torchvision
images = []
files = []
tx = torchvision.transforms.Resize((128, 128))
for f in tqdm(os.listdir("train")):
    img = torchvision.io.read_image("train/" + f)
    images.append(tx(img.float() / 255))
    files.append(f)
torch.save({
    'names': files,
    'images': torch.stack(images),
}, 'x_train_resized_normalized.pt')
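The saved file keeps the image names next to the stacked tensor, so the images can later be matched to rows of the tabular data. A small standard-library sketch of that matching step follows. It assumes the table identifies images by file name without extension; the helper `align_by_image_name` and the example names are hypothetical.

```python
from pathlib import PurePath


def align_by_image_name(file_names, table_rows, key="image"):
    """Return (tensor_index, row) pairs for table rows that have an image file."""
    # Map the file stem (name without extension) to its position in the stacked tensor.
    index_of = {PurePath(name).stem: i for i, name in enumerate(file_names)}
    pairs = []
    for row in table_rows:
        i = index_of.get(row[key])
        if i is not None:
            pairs.append((i, row))
    return pairs


# Hypothetical names; os.listdir() gives no ordering guarantee, hence the lookup.
files = ["ISIC_0000002.jpg", "ISIC_0000001.jpg", "ISIC_0000003.jpg"]
table = [
    {"image": "ISIC_0000001", "subset": "trainval"},
    {"image": "ISIC_0000003", "subset": "test"},
    {"image": "ISIC_0000009", "subset": "tune"},  # no matching file, skipped
]

pairs = align_by_image_name(files, table)
```

The first element of each pair is the index into the `'images'` tensor, so image and tabular features of the same lesion stay aligned even though the directory listing order is arbitrary.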
The eventual data representation is in a single table:
entry1 = po("torch_ingress_ltnsr") %>>%
entry2 = po("torch_ingress_num") %>>%
list(entry1, entry2) %>>%
More fine-grained control looks something like this:
graph = Graph$new()
graph$add_pipeop()
One of the strengths of mlr3torch is that it can easily handle multimodal data. This is because a neural network built out of PipeOpTorch operators can have multiple inputs (PipeOpTorchIngress). To showcase this feature, we need a multimodal example dataset, for which we can use this one: https://challenge2020.isic-archive.com/
Some predefined image tasks already exist in mlr3torch, so integrating this new task will work similarly to https://github.com/mlr-org/mlr3torch/blob/main/R/TaskClassif_mnist.R. To add a new task to mlr3torch, we need to add a function that takes in an ID and returns a task. Then, we need to add this function to the dictionary of tasks as below:
Because the dataset is too large to be contained in the mlr3torch package, we use a DataBackendLazy as the task's backend. Therefore, the load_task_melanoma function first needs to construct this DataBackendLazy and then create a TaskClassif from that DataBackendLazy.
The DataBackendLazy:
The caching is also implemented via the private cached() function, so only the download and processing need to be implemented.
The hardcoded metadata is stored in the ./inst/col_info folder and can be loaded using the private load_column_info() (https://github.com/mlr-org/mlr3torch/tree/main/inst/col_info). The code that can be used to generate this hardcoded metadata should be located in ./data-raw.
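The combination described above (metadata available up front, data built only on first access, and the expensive download and processing step cached) can be illustrated with a language-neutral sketch. The Python below is an analogy only and not mlr3torch's API: `LazyBackend`, this `cached()` helper, the JSON cache file, and the toy columns are all assumptions made for illustration.

```python
import json
import tempfile
from pathlib import Path


class LazyBackend:
    """Hypothetical stand-in for a lazily constructed data backend."""

    def __init__(self, constructor, col_info):
        self._constructor = constructor  # called only on first data access
        self.col_info = col_info         # hardcoded metadata, available immediately
        self._data = None

    @property
    def data(self):
        if self._data is None:
            self._data = self._constructor()
        return self._data


def cached(build, cache_file):
    """Run `build` once and reuse its JSON-serialisable result afterwards."""
    cache_file = Path(cache_file)
    if cache_file.exists():
        return json.loads(cache_file.read_text())
    result = build()
    cache_file.parent.mkdir(parents=True, exist_ok=True)
    cache_file.write_text(json.dumps(result))
    return result


downloads = []  # records how often the expensive step actually runs


def download_and_process():
    """Placeholder for the real download and preprocessing."""
    downloads.append(1)
    return {"image": ["a.jpg", "b.jpg"], "age_approx": [45, 60]}


cache_file = Path(tempfile.mkdtemp()) / "melanoma.json"

backend = LazyBackend(
    constructor=lambda: cached(download_and_process, cache_file),
    col_info={"image": "character", "age_approx": "numeric"},
)
```

Creating `backend` triggers no download; the column metadata can be inspected right away, and a second backend pointing at the same cache file reuses the stored result instead of downloading again.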