Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Track broadcast axes in the shape_transform_descriptor #3610

Open
wants to merge 18 commits into
base: develop
Choose a base branch
from
7 changes: 6 additions & 1 deletion src/include/migraphx/shape_transform_descriptor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,12 @@ struct MIGRAPHX_EXPORT shape_transform_descriptor
// the axis. However, it still needs to accounted for. After we
// generate the broadcast we will set the axis to the hidden
// axis, and then length to 1.
optional<std::size_t> hidden_axis = nullopt;
std::vector<std::size_t> hidden_axis = {};

const std::vector<std::size_t>& origin_axis() const;
bool has_hidden_axis() const;

void add_split_axis(std::size_t i);

MIGRAPHX_EXPORT friend bool operator==(const sub& x, const sub& y);
MIGRAPHX_EXPORT friend bool operator!=(const sub& x, const sub& y);
Expand Down
178 changes: 150 additions & 28 deletions src/shape_transform_descriptor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,12 @@
#include <migraphx/algorithm.hpp>
#include <migraphx/output_iterator.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/erase.hpp>
#include <migraphx/common_dims.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/stringutils.hpp>
#include <map>
#include <unordered_set>
#include <deque>

namespace migraphx {
Expand Down Expand Up @@ -99,6 +101,16 @@
return s.lens();
}

// Returns a pointer to the last subdimension of the last dimension in `dims`.
// Yields nullptr when `dims` is empty or when its last dimension has no
// subdimensions. The returned pointer refers to an element stored in `dims`.
static dimension::sub* get_last_subdimension(std::vector<dimension>& dims)
{
    if(dims.empty())
        return nullptr;
    auto& last_subs = dims.back().subdimensions;
    return last_subs.empty() ? nullptr : &last_subs.back();
}

bool shape_transform_descriptor::apply(const std::vector<operation>& ops)
{
std::vector<std::size_t> dims;
Expand Down Expand Up @@ -196,8 +208,7 @@
r += n;
transform(range(n + 1), std::back_inserter(new_dims), [&](auto j) -> dimension {
auto new_sub = sub;
if(not new_sub.axis.empty())
new_sub.axis.push_back(j);
new_sub.add_split_axis(j);
new_sub.len = start[j];
return {{new_sub}};
});
Expand All @@ -209,12 +220,20 @@
// Handle trailing 1s
if(new_dims.size() < rdims.size() and not new_dims.empty())
{
for(auto d : range(rdims.begin() + new_dims.size(), rdims.end()))
{
if(d != 1)
return false;
new_dims.push_back({{dimension::sub{1}}});
}
auto* sub = get_last_subdimension(new_dims);
auto axis = sub == nullptr ? std::vector<std::size_t>{} : sub->axis;
auto trailing_dims = range(rdims.begin() + new_dims.size(), rdims.end());
if(any_of(trailing_dims, [](auto d) { return d != 1; }))
return false;
if(distance(trailing_dims) > 1)
sub->add_split_axis(0);
transform(range(distance(trailing_dims)),
std::back_inserter(new_dims),
[&](std::size_t j) -> dimension {
dimension::sub s{1, axis};
s.add_split_axis(j + 1);
return {{s}};
});
}
assert(rdims.size() == new_dims.size());
if(rdims.size() != new_dims.size())
Expand Down Expand Up @@ -252,7 +271,20 @@
return dim;
if(dim.len() != 1)
MIGRAPHX_THROW("Wrong out_lens for broadcast");
return {{dimension::sub{len, {}}}};
auto new_subs = dim.subdimensions;
if(not new_subs.empty())
{
new_subs.front().len = len;
}
for(auto& s : new_subs)
{
if(not s.axis.empty())
{
s.hidden_axis = s.axis;
s.axis.clear();
}
}
return {new_subs};
});
std::transform(out_lens.begin() + offset + dimensions.size(),
out_lens.end(),
Expand Down Expand Up @@ -281,14 +313,19 @@
remove_1_sub_dims(subdimensions);
// Flatten adjacent dimensions
adjacent_for_each(subdimensions.begin(), subdimensions.end(), [&](sub& d1, sub& d2) {
if(d1.axis.size() < 2)
if(d1.origin_axis().size() < 2)
return;
if(d2.axis.size() < 2)
if(d2.origin_axis().size() < 2)
return;
if(not std::equal(d1.axis.begin(), d1.axis.end() - 1, d2.axis.begin(), d2.axis.end() - 1))
if(d1.has_hidden_axis() != d2.has_hidden_axis())
return;
auto a1 = d1.axis.back();
auto a2 = d2.axis.back();
if(not std::equal(d1.origin_axis().begin(),
d1.origin_axis().end() - 1,
d2.origin_axis().begin(),
d2.origin_axis().end() - 1))
return;
auto a1 = d1.origin_axis().back();
auto a2 = d2.origin_axis().back();
assert(a2 != a1);
if(a2 <= a1)
return;
Expand Down Expand Up @@ -355,7 +392,7 @@
if(d.subdimensions.empty())
d.subdimensions.push_back({1, {axis}});
else
d.subdimensions.front().hidden_axis = axis;
d.subdimensions.front().hidden_axis = {axis};
}

// Group all axes into a map with a key of the axis and the value is vector of
Expand All @@ -368,14 +405,77 @@
{
for(auto& s : d.subdimensions)
{
if(s.axis.empty())
if(s.origin_axis().empty())
continue;
axes_map[s.axis.front()].push_back(&s);
axes_map[s.origin_axis().front()].push_back(&s);
}
}
return axes_map;
}

// Assigns `axis` to whichever axis vector of `s` is currently in use: the
// hidden axis when `s` reports a hidden axis, otherwise the regular axis.
static void set_origin_axis(dimension::sub& s, const std::vector<std::size_t>& axis)
{
    auto& target = s.has_hidden_axis() ? s.hidden_axis : s.axis;
    target       = axis;
}

// If an axis is split and some dimensions are hidden and others are not, then
// remove the hidden axis so only the non-hidden axis is used in
// simplification

Check warning on line 426 in src/shape_transform_descriptor.cpp

View workflow job for this annotation

GitHub Actions / misspell

[misspell] src/shape_transform_descriptor.cpp#L426

"simplificaiton" is a misspelling of "simplification"
Raw output
./src/shape_transform_descriptor.cpp:426:3: "simplificaiton" is a misspelling of "simplification"
// For every axis group that contains at least one sub without a hidden axis,
// clear the hidden axis of the subs that have one and drop from the group any
// sub left without an axis. Groups that end up empty are erased from the map.
static void remove_split_hidden_axes(std::map<std::size_t, std::vector<dimension::sub*>>& axes_map)
{
    const auto is_visible = [](const dimension::sub* s) { return not s->has_hidden_axis(); };
    const auto has_no_axis = [](const dimension::sub* s) {
        return s->axis.empty() and s->hidden_axis.empty();
    };
    for(auto& entry : axes_map)
    {
        auto& group = entry.second;
        // A group made up only of hidden subs is left untouched
        if(not std::any_of(group.begin(), group.end(), is_visible))
            continue;
        for(auto* member : group)
        {
            if(member->has_hidden_axis())
                member->hidden_axis.clear();
        }
        // Remove the subdimensions that no longer have an axis
        const auto stale = std::remove_if(group.begin(), group.end(), has_no_axis);
        group.erase(stale, group.end());
    }
    // Remove axis from group if empty
    erase_if(axes_map, [](auto&& p) { return p.second.empty(); });
}

// If this is scalar, then remove all axes
static void remove_scalar_axis(std::vector<dimension>& dimensions)
{
dimension::sub* s = nullptr;
for(auto& d : dimensions)
{
auto has_axis = [](const dimension::sub& x) { return not x.origin_axis().empty(); };
auto it = std::find_if(d.subdimensions.begin(), d.subdimensions.end(), has_axis);
if(it == d.subdimensions.end())
continue;
if(s != nullptr)
return;
if(std::count_if(std::next(it), d.subdimensions.end(), has_axis) > 0)
return;
s = &*it;
}
if(s != nullptr)
{
if(s->has_hidden_axis())
s->hidden_axis.clear();
if(s->len == 1)
s->axis.clear();
}
}

// Renumber all axes while preserving the order of the axes
static void renumber_axes(std::map<std::size_t, std::vector<dimension::sub*>>& axes_map)
{
Expand All @@ -385,15 +485,15 @@
auto& subs = p.second;
if(subs.size() == 1)
{
subs[0]->axis = {axis};
set_origin_axis(*subs[0], {axis});
}
else
{
std::sort(subs.begin(), subs.end(), by(std::less<>{}, [](const dimension::sub* s) {
return s->axis;
return s->origin_axis();
}));
for(std::size_t i : range(subs.size()))
subs[i]->axis = {axis, i};
set_origin_axis(*subs[i], {axis, i});
}
}
}
Expand Down Expand Up @@ -437,6 +537,8 @@
for(auto& d : dimensions)
d.simplify();

remove_scalar_axis(dimensions);

std::map<std::size_t, std::size_t> missing_axes;
std::vector<std::size_t> last_axis;
{
Expand All @@ -445,6 +547,7 @@
if(axes_map.empty())
return;

remove_split_hidden_axes(axes_map);
renumber_axes(axes_map);

// Find last axis
Expand Down Expand Up @@ -611,17 +714,15 @@
if(s.axis.empty())
{
s.len = 1;
if(s.hidden_axis.has_value())
{
s.axis = {s.hidden_axis.value()};
s.hidden_axis = nullopt;
}
s.axis = s.hidden_axis;
s.hidden_axis.clear();
}
}

static operation make_reshape_unsqueeze(const std::vector<dimension::sub>& subs)
{
bool use_reshape = false;
std::unordered_set<std::size_t> all_1s;
// Check if split dimensions are all additional 1s
if(std::any_of(
subs.begin(), subs.end(), [](const dimension::sub& s) { return s.axis.size() > 1; }))
Expand All @@ -645,6 +746,8 @@
// Number of elements that are 1
auto n1 =
std::count_if(start, last, [](const dimension::sub& s) { return s.len == 1; });
if(n == n1 and not start->axis.empty())
all_1s.insert(start->axis.front());
use_reshape |= std::max<int64_t>(0, n - n1 - 1) > 0;
},
by_axis);
Expand Down Expand Up @@ -672,6 +775,8 @@
continue;
if(sub.len != 1 and not sub.axis.empty())
continue;
if(not sub.axis.empty() and contains(all_1s, sub.axis.front()) and sub.axis.back() == 0)
continue;
axes.push_back(i);
}
return make_op("unsqueeze", {{"axes", axes}});
Expand All @@ -681,7 +786,7 @@
static bool has_no_axes(const dimension& d)
{
return std::all_of(d.subdimensions.begin(), d.subdimensions.end(), [](const dimension::sub& s) {
return s.axis.empty() and not s.hidden_axis.has_value();
return s.axis.empty() and s.hidden_axis.empty();
});
}
static bool has_axes(const dimension& d)
Expand Down Expand Up @@ -824,6 +929,23 @@
[](const auto& s) { return s.len(); });
}

// Returns `axis` when it is non-empty; otherwise returns `hidden_axis`, which
// may itself be empty.
const std::vector<std::size_t>& shape_transform_descriptor::dimension::sub::origin_axis() const
{
    if(axis.empty())
        return hidden_axis;
    return axis;
}
// True only when `axis` is empty and `hidden_axis` is not.
bool shape_transform_descriptor::dimension::sub::has_hidden_axis() const
{
    if(not axis.empty())
        return false;
    return not hidden_axis.empty();
}

// Appends the split index `i` to `axis` and to `hidden_axis`, skipping
// whichever of the two is empty.
void shape_transform_descriptor::dimension::sub::add_split_axis(std::size_t i)
{
    const auto extend = [i](std::vector<std::size_t>& path) {
        if(path.empty())
            return;
        path.push_back(i);
    };
    extend(axis);
    extend(hidden_axis);
}

bool operator==(const dimension::sub& x, const dimension::sub& y)
{
return by(std::equal_to<>{},
Expand All @@ -833,8 +955,8 @@
std::ostream& operator<<(std::ostream& os, const dimension::sub& x)
{
os << x.len << ":" << to_string_range(x.axis, "x");
if(x.hidden_axis.has_value())
os << "$" << x.hidden_axis.value();
if(not x.hidden_axis.empty())
os << "$" << to_string_range(x.hidden_axis, "x");
return os;
}
bool operator==(const dimension& x, const dimension& y)
Expand Down
3 changes: 3 additions & 0 deletions src/tf/parse_conv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ struct parse_conv : op_parser<parse_conv>
tf_parser::node_info info,
std::vector<instruction_ref> args) const
{
if(contains(info.attributes, "data_format"))
std::cout << "data_format: " << info.attributes.at("data_format").s() << std::endl;

op::convolution op;
if(contains(info.attributes, "strides"))
{
Expand Down
Loading