Add tests for Conv2D, elementwise add/subtract/multiply (#276)
kadinlz authored May 14, 2022
1 parent f717c75 commit 1e1191c
Showing 13 changed files with 588 additions and 0 deletions.
61 changes: 61 additions & 0 deletions align/add/align_add_ff.py
@@ -0,0 +1,61 @@
import os
import sys

import torch
from flexflow.core import *
from flexflow.core.flexflow_cffi import Linear, Op, Parameter
from flexflow.type import AggrMode

sys.path.append("./align/")
from align_ff_utils import (compile_ffmodel, init_ffmodel, run_fwd_bwd,
                            save_param_ff, save_param_grad_ff, save_tensor_ff,
                            save_tensor_grad_ff)
from align_utils import BATCH_SIZE, gen_tensor

OUT_DIR = os.path.join("align", "add", "out")


def run():
    INPUT_SIZE = 512
    SEQ_LENGTH = 5
    inp1: torch.Tensor = gen_tensor(
        (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE),
        dtype="float32"
    )
    inp2: torch.Tensor = gen_tensor(
        (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE),
        dtype="float32"
    )
    label: torch.Tensor = gen_tensor(
        (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE),
        dtype="float32"
    )

    ffconfig = FFConfig()
    ffmodel = FFModel(ffconfig)
    input_tensor_1 = ffmodel.create_tensor(inp1.shape, DataType.DT_FLOAT)
    input_tensor_2 = ffmodel.create_tensor(inp2.shape, DataType.DT_FLOAT)
    output_tensor = ffmodel.add(
        x=input_tensor_1,
        y=input_tensor_2,
        name="add"
    )

    # compile
    compile_ffmodel(ffmodel)
    dls = init_ffmodel(ffmodel, ((input_tensor_1, inp1), (input_tensor_2, inp2)), label)
    assert len(dls) == 3
    inp1_dl, inp2_dl, label_dl = dls

    # forward/backward pass
    run_fwd_bwd(ffmodel, ffconfig, (inp1_dl, inp2_dl), label_dl)

    # save data
    save_tensor_ff(output_tensor, ffmodel, os.path.join(OUT_DIR, "ff_out.pt"))
    save_tensor_grad_ff(output_tensor, ffmodel, os.path.join(OUT_DIR, "ff_out_grad.pt"))


if __name__ == "__main__":
    run()
44 changes: 44 additions & 0 deletions align/add/align_add_torch.py
@@ -0,0 +1,44 @@
import os
import sys

import torch

sys.path.append("./align/")
from align_utils import gen_tensor, BATCH_SIZE

assert torch.cuda.is_available(), "Expects at least one GPU"
DEVICE = torch.device(0)
OUT_DIR = os.path.join("align", "add", "out")


def run():
    INPUT_SIZE = 512
    SEQ_LENGTH = 5

    inp1: torch.Tensor = gen_tensor(
        (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE),
        dtype="float32"
    ).to(DEVICE)
    inp2: torch.Tensor = gen_tensor(
        (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE),
        dtype="float32"
    ).to(DEVICE)
    label: torch.Tensor = gen_tensor(
        (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE),
        dtype="float32"
    ).to(DEVICE)
    # inp1 and inp2 are already on DEVICE, so the sum is as well
    output = torch.add(
        input=inp1,
        other=inp2
    )
    # mark the output as requiring grad so that loss.backward() populates
    # output.grad for comparison against FlexFlow
    output.requires_grad = True
    output.retain_grad()

    loss_fn = torch.nn.MSELoss(reduction="mean")
    loss = loss_fn(output, label)
    loss.backward()
    torch.save(output.cpu(), os.path.join(OUT_DIR, "torch_out.pt"))
    torch.save(output.grad.cpu(), os.path.join(OUT_DIR, "torch_out_grad.pt"))


if __name__ == "__main__":
    run()
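Both halves of each alignment test rely on gen_tensor from align_utils (not shown in this diff) returning identical data in the separate FlexFlow and PyTorch processes. A minimal sketch of such a helper, assuming a fixed seed, uniform random values, and a BATCH_SIZE mirroring the -b 16 flag in gen_tensors.sh; the repository's actual implementation may differ:

    import torch

    # seeded once at import time; each script imports this module fresh, so
    # successive gen_tensor calls yield the same sequence in both processes
    _GENERATOR = torch.Generator().manual_seed(42)

    BATCH_SIZE = 16


    def gen_tensor(shape, dtype: str = "float32") -> torch.Tensor:
        """Deterministically generate a tensor (sketch, not the actual helper)."""
        return torch.rand(shape, generator=_GENERATOR, dtype=getattr(torch, dtype))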
7 changes: 7 additions & 0 deletions align/add/gen_tensors.sh
@@ -0,0 +1,7 @@
eval "$(conda shell.bash hook)";
rm align/add/out/*.pt;
conda activate flexflow;
./python/flexflow_python align/add/align_add_ff.py -ll:py 1 -ll:gpu 1 -ll:fsize 5000 -ll:zsize 4096 -b 16;
conda activate pytorch;
python align/add/align_add_torch.py;

111 changes: 111 additions & 0 deletions align/align_test.py
@@ -104,3 +104,114 @@ def test_getitem():
            ),
        ]
    )


def test_linear():
    out_dir = os.path.join(BASE_DIR, "linear", "out")
    expand = prepend_dirname_fn(out_dir)
    align_tensors(
        [
            TensorAlignmentData(
                "linear_out",
                expand("ff_out.pt"),
                expand("torch_out.pt"),
            ),
            TensorAlignmentData(
                "linear_out_grad",
                expand("ff_out_grad.pt"),
                expand("torch_out_grad.pt"),
            ),
            TensorAlignmentData(
                "linear_weight_grad",
                expand("ff_weight_grad.pt"),
                expand("torch_weight_grad.pt"),
            ),
            TensorAlignmentData(
                "linear_bias_grad",
                expand("ff_bias_grad.pt"),
                expand("torch_bias_grad.pt"),
            ),
        ]
    )


def test_conv2d():
    out_dir = os.path.join(BASE_DIR, "conv2d", "out")
    expand = prepend_dirname_fn(out_dir)
    align_tensors(
        [
            TensorAlignmentData(
                "conv2d_out",
                expand("ff_out.pt"),
                expand("torch_out.pt"),
            ),
            TensorAlignmentData(
                "conv2d_out_grad",
                expand("ff_out_grad.pt"),
                expand("torch_out_grad.pt"),
            ),
            TensorAlignmentData(
                "conv2d_weight_grad",
                expand("ff_weight_grad.pt"),
                expand("torch_weight_grad.pt"),
            ),
            TensorAlignmentData(
                "conv2d_bias_grad",
                expand("ff_bias_grad.pt"),
                expand("torch_bias_grad.pt"),
            ),
        ]
    )


def test_add():
    out_dir = os.path.join(BASE_DIR, "add", "out")
    expand = prepend_dirname_fn(out_dir)
    align_tensors(
        [
            TensorAlignmentData(
                "add_out",
                expand("ff_out.pt"),
                expand("torch_out.pt"),
            ),
            TensorAlignmentData(
                "add_out_grad",
                expand("ff_out_grad.pt"),
                expand("torch_out_grad.pt"),
            ),
        ]
    )


def test_subtract():
    out_dir = os.path.join(BASE_DIR, "subtract", "out")
    expand = prepend_dirname_fn(out_dir)
    align_tensors(
        [
            TensorAlignmentData(
                "subtract_out",
                expand("ff_out.pt"),
                expand("torch_out.pt"),
            ),
            TensorAlignmentData(
                "subtract_out_grad",
                expand("ff_out_grad.pt"),
                expand("torch_out_grad.pt"),
            ),
        ]
    )


def test_multiply():
    out_dir = os.path.join(BASE_DIR, "multiply", "out")
    expand = prepend_dirname_fn(out_dir)
    align_tensors(
        [
            TensorAlignmentData(
                "multiply_out",
                expand("ff_out.pt"),
                expand("torch_out.pt"),
            ),
            TensorAlignmentData(
                "multiply_out_grad",
                expand("ff_out_grad.pt"),
                expand("torch_out_grad.pt"),
            ),
        ]
    )
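These tests lean on TensorAlignmentData and align_tensors, which live elsewhere under align/ and are not part of this diff. A rough sketch of the check they presumably perform; the field names match the usage above, but the structure and tolerance here are assumptions:

    from typing import List, NamedTuple

    import torch


    class TensorAlignmentData(NamedTuple):
        tensor_name: str
        ff_filepath: str
        torch_filepath: str


    def align_tensors(alignment_data: List[TensorAlignmentData]):
        """Assert that each saved FlexFlow tensor matches its PyTorch counterpart."""
        for data in alignment_data:
            ff_tensor = torch.load(data.ff_filepath).detach().cpu()
            torch_tensor = torch.load(data.torch_filepath).detach().cpu()
            assert torch.allclose(ff_tensor, torch_tensor, atol=1e-5), (
                f"{data.tensor_name}: FlexFlow and PyTorch tensors do not align"
            )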
73 changes: 73 additions & 0 deletions align/conv2d/align_conv2d_ff.py
@@ -0,0 +1,73 @@
import os
import sys

import torch
from flexflow.core import *
from flexflow.core.flexflow_cffi import Conv2D, Op, Parameter
from flexflow.type import AggrMode

sys.path.append("./align/")
from align_ff_utils import (compile_ffmodel, init_ffmodel, run_fwd_bwd,
                            save_param_ff, save_param_grad_ff, save_tensor_ff,
                            save_tensor_grad_ff)
from align_utils import BATCH_SIZE, gen_tensor

OUT_DIR = os.path.join("align", "conv2d", "out")


def run():
    KERNEL_SIZE = 3
    INPUT_SIZE = 512
    IN_CHANNELS = 3
    OUTPUT_SIZE = 510
    OUT_CHANNELS = 5
    inp: torch.Tensor = gen_tensor(
        (BATCH_SIZE, IN_CHANNELS, INPUT_SIZE, INPUT_SIZE),
        dtype="float32"
    )
    label: torch.Tensor = gen_tensor(
        (BATCH_SIZE, OUT_CHANNELS, OUTPUT_SIZE, OUTPUT_SIZE),
        dtype="float32"
    )

    ffconfig = FFConfig()
    ffmodel = FFModel(ffconfig)
    input_tensor = ffmodel.create_tensor(inp.shape, DataType.DT_FLOAT)
    output_tensor = ffmodel.conv2d(
        input=input_tensor,
        out_channels=OUT_CHANNELS,
        kernel_h=KERNEL_SIZE,
        kernel_w=KERNEL_SIZE,
        stride_h=1,
        stride_w=1,
        padding_h=0,
        padding_w=0,
        name="conv2d"
    )

    # compile model
    compile_ffmodel(ffmodel)
    dls = init_ffmodel(ffmodel, ((input_tensor, inp),), label)
    assert len(dls) == 2
    inp_dl, label_dl = dls

    # forward/backward pass
    run_fwd_bwd(ffmodel, ffconfig, (inp_dl,), label_dl)

    conv2d_layer: Op = ffmodel.get_layers()[0]
    assert isinstance(conv2d_layer, Conv2D)
    conv2d_weight: Parameter = conv2d_layer.get_weight_tensor()
    conv2d_bias: Parameter = conv2d_layer.get_bias_tensor()

    # save output data
    save_tensor_ff(output_tensor, ffmodel, os.path.join(OUT_DIR, "ff_out.pt"))
    save_tensor_grad_ff(output_tensor, ffmodel, os.path.join(OUT_DIR, "ff_out_grad.pt"))

    # save layer data
    save_param_ff(conv2d_weight, ffmodel, os.path.join(OUT_DIR, "ff_weight.pt"))
    save_param_ff(conv2d_bias, ffmodel, os.path.join(OUT_DIR, "ff_bias.pt"))
    save_param_grad_ff(conv2d_weight, ffmodel, os.path.join(OUT_DIR, "ff_weight_grad.pt"))
    save_param_grad_ff(conv2d_bias, ffmodel, os.path.join(OUT_DIR, "ff_bias_grad.pt"))


if __name__ == "__main__":
    run()
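The hard-coded OUTPUT_SIZE = 510 above follows from the standard convolution output-size formula: with a 512x512 input, a 3x3 kernel, stride 1, and zero padding, the spatial size is (512 + 2*0 - 3)/1 + 1 = 510. A small helper, illustrative only and not part of this commit, makes the arithmetic explicit:

    def conv2d_output_size(in_size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
        # standard formula for the spatial output size of a 2D convolution
        return (in_size + 2 * padding - kernel) // stride + 1


    assert conv2d_output_size(512, 3) == 510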
61 changes: 61 additions & 0 deletions align/conv2d/align_conv2d_torch.py
@@ -0,0 +1,61 @@
import os
import sys

import torch

sys.path.append("./align/")
from align_utils import gen_tensor, BATCH_SIZE

assert torch.cuda.is_available(), "Expects at least one GPU"
DEVICE = torch.device(0)
OUT_DIR = os.path.join("align", "conv2d", "out")


def run():
    KERNEL_SIZE = 3
    INPUT_SIZE = 512
    IN_CHANNELS = 3
    OUTPUT_SIZE = 510
    OUT_CHANNELS = 5
    conv2d = torch.nn.Conv2d(
        in_channels=IN_CHANNELS,
        out_channels=OUT_CHANNELS,
        kernel_size=KERNEL_SIZE
    ).to(DEVICE)

    # load the weight and bias saved by the FlexFlow run so that both
    # frameworks start from identical parameters
    ff_weight = torch.load(os.path.join(OUT_DIR, "ff_weight.pt"))
    ff_bias = torch.load(os.path.join(OUT_DIR, "ff_bias.pt"))
    assert conv2d.weight.shape == ff_weight.shape, (
        f"Shape mismatch: FF={ff_weight.shape} torch={conv2d.weight.shape}"
    )
    assert conv2d.bias.shape == ff_bias.shape, (
        f"Shape mismatch: FF={ff_bias.shape} torch={conv2d.bias.shape}"
    )

    conv2d.weight = torch.nn.Parameter(ff_weight.to(DEVICE))
    conv2d.bias = torch.nn.Parameter(ff_bias.to(DEVICE))

    # generate input/label tensors, imitating a 3-channel image input
    inp: torch.Tensor = gen_tensor(
        (BATCH_SIZE, IN_CHANNELS, INPUT_SIZE, INPUT_SIZE),
        dtype="float32"
    ).to(DEVICE)
    label: torch.Tensor = gen_tensor(
        (BATCH_SIZE, OUT_CHANNELS, OUTPUT_SIZE, OUTPUT_SIZE),
        dtype="float32"
    ).to(DEVICE)

    output = conv2d(inp)
    conv2d.zero_grad()
    output.retain_grad()
    loss_fn = torch.nn.MSELoss(reduction="mean")
    loss = loss_fn(output, label)
    loss.backward()
    torch.save(output.cpu(), os.path.join(OUT_DIR, "torch_out.pt"))
    torch.save(output.grad.cpu(), os.path.join(OUT_DIR, "torch_out_grad.pt"))
    torch.save(conv2d.weight.grad.cpu(), os.path.join(OUT_DIR, "torch_weight_grad.pt"))
    torch.save(conv2d.bias.grad.cpu(), os.path.join(OUT_DIR, "torch_bias_grad.pt"))


if __name__ == "__main__":
    run()
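Once both scripts have populated align/conv2d/out/, the new alignment test can be run. Assuming the suite is driven by pytest (the usual convention for test_* functions, though the runner is not shown in this diff):

    pytest align/align_test.py -k conv2d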
7 changes: 7 additions & 0 deletions align/conv2d/gen_tensors.sh
@@ -0,0 +1,7 @@
# regenerate the alignment tensors for the conv2d test; the FlexFlow run must
# come first, since it saves the weights that the PyTorch run loads
eval "$(conda shell.bash hook)";
rm align/conv2d/out/*.pt;
conda activate flexflow;
./python/flexflow_python align/conv2d/align_conv2d_ff.py -ll:py 1 -ll:gpu 1 -ll:fsize 5000 -ll:zsize 4096 -b 16;
conda activate pytorch;
python align/conv2d/align_conv2d_torch.py;
