Skip to content

Commit

Permalink
Fix return stmt cast to 1 when it's not a scalar
Browse files Browse the repository at this point in the history
  • Loading branch information
kchristin22 committed Nov 21, 2024
1 parent 770062c commit 7e8c08d
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 5 deletions.
11 changes: 7 additions & 4 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1216,7 +1216,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
const Expr* value = RS->getRetValue();
QualType type = value->getType();
auto* dfdf = m_Pullback;
if (dfdf && (isa<FloatingLiteral>(dfdf) || isa<IntegerLiteral>(dfdf))) {
if (dfdf && (isa<FloatingLiteral>(dfdf) || isa<IntegerLiteral>(dfdf)) &&
type->isScalarType()) {
ExprResult tmp = dfdf;
dfdf = m_Sema
.ImpCastExprToType(tmp.get(), type,
Expand Down Expand Up @@ -1277,6 +1278,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
}

StmtDiff ReverseModeVisitor::VisitInitListExpr(const InitListExpr* ILE) {
if (!dfdx())
return StmtDiff(Clone(ILE));
QualType ILEType = ILE->getType();
llvm::SmallVector<Expr*, 16> clonedExprs(ILE->getNumInits());
if (isArrayOrPointerType(ILEType)) {
Expand All @@ -1302,12 +1305,12 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
auto field_iterator = ILEType->getAsCXXRecordDecl()->field_begin();
std::advance(field_iterator, i);
Expr* member_acess = nullptr;
if (dfdx())
member_acess = utils::BuildMemberExpr(
m_Sema, getCurrentScope(), dfdx(), (*field_iterator)->getName());
member_acess = utils::BuildMemberExpr(m_Sema, getCurrentScope(), dfdx(),
(*field_iterator)->getName());
clonedExprs[i] = Visit(ILE->getInit(i), member_acess).getExpr();
}
Expr* clonedILE = m_Sema.ActOnInitList(noLoc, clonedExprs, noLoc).get();
printf("before cloning\n");
return StmtDiff(clonedILE);
}

Expand Down
2 changes: 1 addition & 1 deletion test/ForwardMode/STLCustomDerivatives.C
Original file line number Diff line number Diff line change
Expand Up @@ -426,4 +426,4 @@ int main() {
TEST_DIFFERENTIATE(fnArr1, 3); // CHECK-EXEC: {3.00}
TEST_DIFFERENTIATE(fnArr2, 3); // CHECK-EXEC: {108.00}
TEST_DIFFERENTIATE(fnTuple1, 3, 4); // CHECK-EXEC: {2.00}
}
}

0 comments on commit 7e8c08d

Please sign in to comment.