diff --git a/app/controllers/CheckSalesController.scala b/app/controllers/CheckSalesController.scala index 9e6a727a..093e17c9 100644 --- a/app/controllers/CheckSalesController.scala +++ b/app/controllers/CheckSalesController.scala @@ -19,7 +19,7 @@ package controllers import controllers.actions._ import forms.CheckSalesFormProvider import models.Index -import pages.{CheckSalesPage, Waypoints} +import pages.{CheckSalesPage, SalesToCountryPage, Waypoints} import play.api.data.Form import play.api.i18n.{I18nSupport, MessagesApi} import play.api.mvc.{Action, AnyContent, MessagesControllerComponents} @@ -85,28 +85,36 @@ class CheckSalesController @Inject()( def onSubmit(waypoints: Waypoints, countryIndex: Index, incompletePromptShown: Boolean): Action[AnyContent] = cc.authAndRequireData().async { implicit request => - withCompleteDataAsync[VatRateWithOptionalSalesFromCountry]( - countryIndex, - data = getIncompleteVatRateAndSales _, - onFailure = (_: Seq[VatRateWithOptionalSalesFromCountry]) => { - if(incompletePromptShown) { - Redirect(routes.SalesToCountryController.onPageLoad(waypoints, countryIndex, Index(0))).toFuture - } else { - Redirect(routes.CheckSalesController.onPageLoad(waypoints, countryIndex)).toFuture - } - }) { + getCountry(waypoints, countryIndex) { country => + getAllVatRatesFromCountry(waypoints, countryIndex) { vatRates => + + val period = request.userAnswers.period - getCountry(waypoints, countryIndex) { country => - getAllVatRatesFromCountry(waypoints, countryIndex) { vatRates => + val vatRateIndex = Index(vatRates.vatRatesFromCountry.map(_.size).getOrElse(0) - 1) - val period = request.userAnswers.period + val remainingVatRates = vatRateService.getRemainingVatRatesForCountry(period, country, vatRates) - val remainingVatRates = vatRateService.getRemainingVatRatesForCountry(period, country, vatRates) + val canAddAnotherVatRate = remainingVatRates.nonEmpty - val canAddAnotherVatRate = remainingVatRates.nonEmpty + val checkSalesSummary = CheckSalesSummary.rows(request.userAnswers, waypoints, countryIndex) - val checkSalesSummary = CheckSalesSummary.rows(request.userAnswers, waypoints, countryIndex) + val salesToCountry = request.userAnswers.get(SalesToCountryPage(countryIndex, vatRateIndex)) + withCompleteDataAsync[VatRateWithOptionalSalesFromCountry]( + countryIndex, + data = getIncompleteVatRateAndSales _, + onFailure = (_: Seq[VatRateWithOptionalSalesFromCountry]) => { + if(incompletePromptShown) { + salesToCountry match { + case Some(_) => + Redirect(routes.VatOnSalesController.onPageLoad(waypoints, countryIndex, vatRateIndex)).toFuture + case None => + Redirect(routes.SalesToCountryController.onPageLoad(waypoints, countryIndex, vatRateIndex)).toFuture + } + } else { + Redirect(routes.CheckSalesController.onPageLoad(waypoints, countryIndex)).toFuture + } + }) { form.bindFromRequest().fold( formWithErrors => BadRequest(view(formWithErrors, waypoints, period, checkSalesSummary, countryIndex, country, canAddAnotherVatRate, Seq.empty)).toFuture,