-
Notifications
You must be signed in to change notification settings - Fork 229
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add tests for Conv2D, elementwise add/subtract/multiply (#276)
- Loading branch information
Showing
13 changed files
with
588 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
import os | ||
import sys | ||
|
||
import torch | ||
from flexflow.core import * | ||
from flexflow.core.flexflow_cffi import Linear, Op, Parameter | ||
from flexflow.type import AggrMode | ||
|
||
sys.path.append("./align/") | ||
from align_ff_utils import (compile_ffmodel, init_ffmodel, run_fwd_bwd, | ||
save_param_ff, save_param_grad_ff, save_tensor_ff, | ||
save_tensor_grad_ff) | ||
from align_utils import BATCH_SIZE, gen_tensor | ||
|
||
OUT_DIR = os.path.join("align", "add", "out") | ||
|
||
|
||
def run(): | ||
INPUT_SIZE = 512 | ||
SEQ_LENGTH = 5 | ||
inp1: torch.Tensor = gen_tensor( | ||
(BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE), | ||
dtype="float32" | ||
) | ||
inp2: torch.Tensor = gen_tensor( | ||
(BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE), | ||
dtype="float32" | ||
) | ||
label: torch.Tensor = gen_tensor( | ||
(BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE), | ||
dtype="float32" | ||
) | ||
|
||
ffconfig = FFConfig() | ||
ffmodel = FFModel(ffconfig) | ||
input_tensor_1 = ffmodel.create_tensor(inp1.shape, DataType.DT_FLOAT) | ||
input_tensor_2 = ffmodel.create_tensor(inp2.shape, DataType.DT_FLOAT) | ||
output_tensor = ffmodel.add( | ||
x=input_tensor_1, | ||
y=input_tensor_2, | ||
name="add" | ||
) | ||
|
||
# compile | ||
compile_ffmodel(ffmodel) | ||
dls = init_ffmodel(ffmodel, ((input_tensor_1, inp1), (input_tensor_2, inp2)), label) | ||
assert len(dls) == 3 | ||
inp1_dl, inp2_dl, label_dl = dls | ||
|
||
# forward/backward pass | ||
run_fwd_bwd(ffmodel, ffconfig, (inp1_dl, inp2_dl), label_dl) | ||
|
||
# save data | ||
save_tensor_ff(output_tensor, ffmodel, os.path.join(OUT_DIR, "ff_out.pt")) | ||
save_tensor_grad_ff(output_tensor, ffmodel, os.path.join(OUT_DIR, "ff_out_grad.pt")) | ||
|
||
|
||
|
||
|
||
if __name__ == "__main__": | ||
run() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
import os | ||
import sys | ||
|
||
import torch | ||
|
||
sys.path.append("./align/") | ||
from align_utils import gen_tensor, BATCH_SIZE | ||
|
||
assert torch.cuda.is_available(), "Expects at least one GPU" | ||
DEVICE = torch.device(0) | ||
OUT_DIR = os.path.join("align", "add", "out") | ||
|
||
def run(): | ||
INPUT_SIZE = 512 | ||
SEQ_LENGTH = 5 | ||
|
||
inp1: torch.Tensor = gen_tensor( | ||
(BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE), | ||
dtype="float32" | ||
).to(DEVICE) | ||
inp2: torch.Tensor = gen_tensor( | ||
(BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE), | ||
dtype="float32" | ||
).to(DEVICE) | ||
label: torch.Tensor = gen_tensor( | ||
(BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE), | ||
dtype="float32" | ||
).to(DEVICE) | ||
output = torch.add( | ||
input=inp1, | ||
other=inp2 | ||
).to(DEVICE) | ||
output.requires_grad = True | ||
output.retain_grad() | ||
|
||
loss_fn = torch.nn.MSELoss(reduction="mean") | ||
loss = loss_fn(output, label) | ||
loss.backward() | ||
torch.save(output.cpu(), os.path.join(OUT_DIR, "torch_out.pt")) | ||
torch.save(output.grad.cpu(), os.path.join(OUT_DIR, "torch_out_grad.pt")) | ||
|
||
|
||
if __name__ == "__main__": | ||
run() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
eval "$(conda shell.bash hook)"; | ||
rm align/add/out/*.pt; | ||
conda activate flexflow; | ||
./python/flexflow_python align/add/align_add_ff.py -ll:py 1 -ll:gpu 1 -ll:fsize 5000 -ll:zsize 4096 -b 16; | ||
conda activate pytorch; | ||
python align/add/align_add_torch.py; | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
import os | ||
import sys | ||
|
||
import torch | ||
from flexflow.core import * | ||
from flexflow.core.flexflow_cffi import Linear, Op, Parameter | ||
from flexflow.type import AggrMode | ||
|
||
sys.path.append("./align/") | ||
from align_ff_utils import (compile_ffmodel, init_ffmodel, run_fwd_bwd, | ||
save_param_ff, save_param_grad_ff, save_tensor_ff, | ||
save_tensor_grad_ff) | ||
from align_utils import BATCH_SIZE, gen_tensor | ||
|
||
OUT_DIR = os.path.join("align", "conv2d", "out") | ||
|
||
|
||
def run(): | ||
KERNEL_SIZE = 3 | ||
INPUT_SIZE = 512 | ||
IN_CHANNELS = 3 | ||
OUTPUT_SIZE = 510 | ||
OUT_CHANNELS = 5 | ||
inp: torch.Tensor = gen_tensor( | ||
(BATCH_SIZE, IN_CHANNELS, INPUT_SIZE, INPUT_SIZE), | ||
dtype="float32" | ||
) | ||
label: torch.Tensor = gen_tensor( | ||
(BATCH_SIZE, OUT_CHANNELS, OUTPUT_SIZE, OUTPUT_SIZE), | ||
dtype="float32" | ||
) | ||
|
||
ffconfig = FFConfig() | ||
ffmodel = FFModel(ffconfig) | ||
input_tensor = ffmodel.create_tensor(inp.shape, DataType.DT_FLOAT) | ||
output_tensor = ffmodel.conv2d( | ||
input=input_tensor, | ||
out_channels=OUT_CHANNELS, | ||
kernel_h=KERNEL_SIZE, | ||
kernel_w=KERNEL_SIZE, | ||
stride_h=1, | ||
stride_w=1, | ||
padding_h=0, | ||
padding_w=0, | ||
name="conv2d" | ||
) | ||
|
||
# compile model | ||
compile_ffmodel(ffmodel) | ||
dls = init_ffmodel(ffmodel, ((input_tensor, inp),), label) | ||
assert len(dls) == 2 | ||
inp_dl, label_dl = dls | ||
|
||
# forward/back pass | ||
run_fwd_bwd(ffmodel, ffconfig, (inp_dl,), label_dl) | ||
|
||
conv2d_layer: Op = ffmodel.get_layers()[0] | ||
assert isinstance(conv2d_layer, Conv2D) | ||
conv2d_weight: Parameter = conv2d_layer.get_weight_tensor() | ||
conv2d_bias: Parameter = conv2d_layer.get_bias_tensor() | ||
|
||
# save output data | ||
save_tensor_ff(output_tensor, ffmodel, os.path.join(OUT_DIR, "ff_out.pt")) | ||
save_tensor_grad_ff(output_tensor, ffmodel, os.path.join(OUT_DIR, "ff_out_grad.pt")) | ||
|
||
# save layer data | ||
save_param_ff(conv2d_weight, ffmodel, os.path.join(OUT_DIR, "ff_weight.pt")) | ||
save_param_ff(conv2d_bias, ffmodel, os.path.join(OUT_DIR, "ff_bias.pt")) | ||
save_param_grad_ff(conv2d_weight, ffmodel, os.path.join(OUT_DIR, "ff_weight_grad.pt")) | ||
save_param_grad_ff(conv2d_bias, ffmodel, os.path.join(OUT_DIR, "ff_bias_grad.pt")) | ||
|
||
if __name__ == "__main__": | ||
run() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
import os | ||
import sys | ||
|
||
import torch | ||
|
||
sys.path.append("./align/") | ||
from align_utils import gen_tensor, BATCH_SIZE | ||
|
||
assert torch.cuda.is_available(), "Expects at least one GPU" | ||
DEVICE = torch.device(0) | ||
OUT_DIR = os.path.join("align", "conv2d", "out") | ||
|
||
def run(): | ||
KERNEL_SIZE = 3 | ||
INPUT_SIZE = 512 | ||
IN_CHANNELS = 3 | ||
OUTPUT_SIZE = 510 | ||
OUT_CHANNELS = 5 | ||
conv2d = torch.nn.Conv2d( | ||
in_channels=IN_CHANNELS, | ||
out_channels=OUT_CHANNELS, | ||
kernel_size=KERNEL_SIZE | ||
).to(DEVICE) | ||
|
||
linear_weight = torch.load(os.path.join(OUT_DIR, "ff_weight.pt")) | ||
linear_bias = torch.load(os.path.join(OUT_DIR, "ff_bias.pt")) | ||
assert conv2d.weight.shape == linear_weight.shape, ( | ||
"Shape mismatch: " f"FF={linear_weight.shape} torch={conv2d.weight.shape}" | ||
) | ||
assert conv2d.bias.shape == linear_bias.shape, ( | ||
"Shape mismatch: " f"FF={linear_bias.shape} torch={conv2d.bias.shape}" | ||
) | ||
|
||
conv2d.weight = torch.nn.Parameter(linear_weight.to(DEVICE)) | ||
conv2d.bias = torch.nn.Parameter(linear_bias.to(DEVICE)) | ||
|
||
# generate input/label tensors | ||
# imitating 3-channel image input | ||
inp: torch.Tensor = gen_tensor( | ||
(BATCH_SIZE, 3, INPUT_SIZE, INPUT_SIZE), | ||
dtype="float32" | ||
).to(DEVICE) | ||
label: torch.Tensor = gen_tensor( | ||
(BATCH_SIZE, 5, OUTPUT_SIZE, OUTPUT_SIZE), | ||
dtype="float32" | ||
).to(DEVICE) | ||
|
||
output = conv2d(inp) | ||
conv2d.zero_grad() | ||
output.retain_grad() | ||
loss_fn = torch.nn.MSELoss(reduction="mean") | ||
loss = loss_fn(output, label) | ||
loss.backward() | ||
torch.save(output.cpu(), os.path.join(OUT_DIR, "torch_out.pt")) | ||
torch.save(output.grad.cpu(), os.path.join(OUT_DIR, "torch_out_grad.pt")) | ||
torch.save(conv2d.weight.grad.cpu(), os.path.join(OUT_DIR, "torch_weight_grad.pt")) | ||
torch.save(conv2d.bias.grad.cpu(), os.path.join(OUT_DIR, "torch_bias_grad.pt")) | ||
|
||
|
||
if __name__ == "__main__": | ||
run() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
eval "$(conda shell.bash hook)"; | ||
rm align/conv2d/out/*.pt; | ||
conda activate flexflow; | ||
./python/flexflow_python align/conv2d/align_conv2d_ff.py -ll:py 1 -ll:gpu 1 -ll:fsize 5000 -ll:zsize 4096 -b 16; | ||
conda activate pytorch; | ||
python align/conv2d/align_conv2d_torch.py; | ||
|
Oops, something went wrong.