Skip to content

Commit

Permalink
Fix factor map op for shortlist
Browse files Browse the repository at this point in the history
  • Loading branch information
rhenry-nv committed Jul 9, 2021
1 parent 22d13b3 commit 7be492a
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
4 changes: 2 additions & 2 deletions src/graph/node_operators_binary.h
Original file line number Diff line number Diff line change
Expand Up @@ -1291,9 +1291,9 @@ size_t numLemmas_;
bool hasShortlist_;
public:
AddFactorMaxesOp(const std::vector<Expr>& nodes, bool hasShortlist, size_t groupStart, size_t numLemmas)
: NaryNodeOp(nodes, getShape(nodes, hasShortlist), commonType(std::vector<Expr>(nodes.begin() + 1, nodes.end())) ) {
: NaryNodeOp(nodes, getShape(nodes, hasShortlist), commonType(std::vector<Expr>(nodes.begin() + 1 + (int)hasShortlist, nodes.end())) ) {
groupStart_ = groupStart;
numLemmas_ = hasShortlist? nodes[1]->shape().size(): numLemmas;
numLemmas_ = numLemmas;
hasShortlist_ = hasShortlist;
}

Expand Down
2 changes: 1 addition & 1 deletion src/layers/generic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ namespace marian {
} else {
auto numGroups = getNumFactorGroups();
if(numGroups > 1 && graph()->isInference() && graph()->getBackend()->getDeviceId().type == DeviceType::gpu) {
Expr shortlistIndices = shortlist? constant(shortlist->indices()) : nullptr;
Expr shortlistIndices = shortlist? indices(shortlist->indices()) : nullptr;
Expr lemmaHasFactorGroupTensor = getLemmaHasFactorGroupTensor();
std::vector<Expr> groupLosses(logits_.size());
std::transform(logits_.begin(), logits_.end(), groupLosses.begin(), [](const Ptr<RationalLoss>& loss) -> Expr {return loss->loss();});
Expand Down

1 comment on commit 7be492a

@hieuhoang
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you should hold off on changes to shortlist. Big changes are coming to this code. I don't want you to waste your time

Please sign in to comment.