From 7be492aa446e22da5409c6a2c857fa79765ecf92 Mon Sep 17 00:00:00 2001 From: Rawn Henry Date: Thu, 8 Jul 2021 17:20:48 -0700 Subject: [PATCH] Fix factor map op for shortlist --- src/graph/node_operators_binary.h | 4 ++-- src/layers/generic.cpp | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/graph/node_operators_binary.h b/src/graph/node_operators_binary.h index 0d79986b1..474ead783 100644 --- a/src/graph/node_operators_binary.h +++ b/src/graph/node_operators_binary.h @@ -1291,9 +1291,9 @@ size_t numLemmas_; bool hasShortlist_; public: AddFactorMaxesOp(const std::vector& nodes, bool hasShortlist, size_t groupStart, size_t numLemmas) - : NaryNodeOp(nodes, getShape(nodes, hasShortlist), commonType(std::vector(nodes.begin() + 1, nodes.end())) ) { + : NaryNodeOp(nodes, getShape(nodes, hasShortlist), commonType(std::vector(nodes.begin() + 1 + (int)hasShortlist, nodes.end())) ) { groupStart_ = groupStart; - numLemmas_ = hasShortlist? nodes[1]->shape().size(): numLemmas; + numLemmas_ = numLemmas; hasShortlist_ = hasShortlist; } diff --git a/src/layers/generic.cpp b/src/layers/generic.cpp index c807dc060..5c9a9ccdb 100755 --- a/src/layers/generic.cpp +++ b/src/layers/generic.cpp @@ -89,7 +89,7 @@ namespace marian { } else { auto numGroups = getNumFactorGroups(); if(numGroups > 1 && graph()->isInference() && graph()->getBackend()->getDeviceId().type == DeviceType::gpu) { - Expr shortlistIndices = shortlist? constant(shortlist->indices()) : nullptr; + Expr shortlistIndices = shortlist? indices(shortlist->indices()) : nullptr; Expr lemmaHasFactorGroupTensor = getLemmaHasFactorGroupTensor(); std::vector groupLosses(logits_.size()); std::transform(logits_.begin(), logits_.end(), groupLosses.begin(), [](const Ptr& loss) -> Expr {return loss->loss();});