Skip to content

Commit

Permalink
introduce num_tricks in bridge state and change order or tricks
Browse files Browse the repository at this point in the history
  • Loading branch information
ZiggerZZ committed Sep 18, 2023
1 parent cba58ac commit ab14e62
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 35 deletions.
50 changes: 36 additions & 14 deletions open_spiel/games/bridge/bridge.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ const GameType kGameType{/*short_name=*/"bridge",
{"dealer_vul", GameParameter(false)},
// If true, the non-dealer's side is vulnerable.
{"non_dealer_vul", GameParameter(false)},
// Number of played tricks in observation tensor
{"num_tricks", GameParameter(true)},
}};

std::shared_ptr<const Game> Factory(const GameParameters& params) {
Expand Down Expand Up @@ -130,10 +132,12 @@ BridgeGame::BridgeGame(const GameParameters& params)
BridgeState::BridgeState(std::shared_ptr<const Game> game,
bool use_double_dummy_result,
bool is_dealer_vulnerable,
bool is_non_dealer_vulnerable)
bool is_non_dealer_vulnerable,
int num_tricks)
: State(game),
use_double_dummy_result_(use_double_dummy_result),
is_vulnerable_{is_dealer_vulnerable, is_non_dealer_vulnerable} {
is_vulnerable_{is_dealer_vulnerable, is_non_dealer_vulnerable},
num_tricks_(num_tricks) {
possible_contracts_.fill(true);
}

Expand Down Expand Up @@ -337,17 +341,6 @@ void BridgeState::WriteObservationTensor(Player player,
int this_trick_cards_played = num_cards_played_ % kNumPlayers;
int this_trick_start = history_.size() - this_trick_cards_played;

// Previous trick.
if (current_trick > 0) {
int leader = tricks_[current_trick - 1].Leader();
for (int i = 0; i < kNumPlayers; ++i) {
int card = history_[this_trick_start - kNumPlayers + i].action;
int relative_player = (i + leader + kNumPlayers - player) % kNumPlayers;
ptr[relative_player * kNumCards + card] = 1;
}
}
ptr += kNumPlayers * kNumCards;

// Current trick
if (phase_ != Phase::kGameOver) {
int leader = tricks_[current_trick].Leader();
Expand All @@ -357,13 +350,42 @@ void BridgeState::WriteObservationTensor(Player player,
ptr[relative_player * kNumCards + card] = 1;
}
}

ptr += kNumPlayers * kNumCards;

// Previous tricks
for (int j = current_trick - 1; j >= std::max(0, current_trick - num_tricks_ + 1); --j) {
int leader = tricks_[j].Leader();
for (int i = 0; i < kNumPlayers; ++i) {
int card = history_[this_trick_start - kNumPlayers * (current_trick - j) + i].action;
int relative_player = (i + leader + kNumPlayers - player) % kNumPlayers;
ptr[relative_player * kNumCards + card] = 1;
}
ptr += kNumPlayers * kNumCards;
}

// Move pointer for future tricks to have a fixed size tensor
if (num_tricks_ > current_trick + 1) {
ptr += kNumPlayers * kNumCards * (num_tricks_ - current_trick - 1);
}

// Number of tricks taken by each side.
ptr[num_declarer_tricks_] = 1;
ptr += kNumTricks;
ptr[num_cards_played_ / 4 - num_declarer_tricks_] = 1;
ptr += kNumTricks;

int kPlayTensorSize =
kNumBidLevels // What the contract is
+ kNumDenominations // What trumps are
+ kNumOtherCalls // Undoubled / doubled / redoubled
+ kNumPlayers // Who declarer is
+ kNumVulnerabilities // Vulnerability of the declaring side
+ kNumCards // Our remaining cards
+ kNumCards // Dummy's remaining cards
+ num_tricks_ * kNumPlayers * kNumCards // Number of played tricks
+ kNumTricks // Number of tricks we have won
+ kNumTricks; // Number of tricks they have won
SPIEL_CHECK_EQ(std::distance(values.begin(), ptr),
kPlayTensorSize + kNumObservationTypes);
SPIEL_CHECK_LE(std::distance(values.begin(), ptr), values.size());
Expand Down Expand Up @@ -888,7 +910,7 @@ std::unique_ptr<State> BridgeGame::DeserializeState(
if (!UseDoubleDummyResult()) return Game::DeserializeState(str);
auto state = absl::make_unique<BridgeState>(
shared_from_this(), UseDoubleDummyResult(), IsDealerVulnerable(),
IsNonDealerVulnerable());
IsNonDealerVulnerable(), NumTricks());
std::vector<std::string> lines = absl::StrSplit(str, '\n');
const auto separator = absl::c_find(lines, "Double Dummy Results");
// Double-dummy results.
Expand Down
41 changes: 24 additions & 17 deletions open_spiel/games/bridge/bridge.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,20 +73,6 @@ inline constexpr int kPublicInfoTensorSize =
kAuctionTensorSize // The auction
- kNumCards // But not any player's cards
+ kNumPlayers; // Plus trailing passes
inline constexpr int kPlayTensorSize =
kNumBidLevels // What the contract is
+ kNumDenominations // What trumps are
+ kNumOtherCalls // Undoubled / doubled / redoubled
+ kNumPlayers // Who declarer is
+ kNumVulnerabilities // Vulnerability of the declaring side
+ kNumCards // Our remaining cards
+ kNumCards // Dummy's remaining cards
+ kNumPlayers * kNumCards // Cards played to the previous trick
+ kNumPlayers * kNumCards // Cards played to the current trick
+ kNumTricks // Number of tricks we have won
+ kNumTricks; // Number of tricks they have won
inline constexpr int kObservationTensorSize =
kNumObservationTypes + std::max(kPlayTensorSize, kAuctionTensorSize);
inline constexpr int kMaxAuctionLength =
kNumBids * (1 + kNumPlayers * 2) + kNumPlayers;
inline constexpr Player kFirstPlayer = 0;
Expand Down Expand Up @@ -115,7 +101,7 @@ class Trick {
class BridgeState : public State {
public:
BridgeState(std::shared_ptr<const Game> game, bool use_double_dummy_result,
bool is_dealer_vulnerable, bool is_non_dealer_vulnerable);
bool is_dealer_vulnerable, bool is_non_dealer_vulnerable, int num_tricks);
Player CurrentPlayer() const override;
std::string ActionToString(Player player, Action action) const override;
std::string ToString() const override;
Expand Down Expand Up @@ -193,6 +179,7 @@ class BridgeState : public State {

const bool use_double_dummy_result_;
const bool is_vulnerable_[kNumPartnerships];
const int num_tricks_;

int num_passes_ = 0; // Number of consecutive passes since the last non-pass.
int num_declarer_tricks_ = 0;
Expand Down Expand Up @@ -221,14 +208,31 @@ class BridgeGame : public Game {
std::unique_ptr<State> NewInitialState() const override {
return std::unique_ptr<State>(
new BridgeState(shared_from_this(), UseDoubleDummyResult(),
IsDealerVulnerable(), IsNonDealerVulnerable()));
IsDealerVulnerable(), IsNonDealerVulnerable(), NumTricks()));
}
int NumPlayers() const override { return kNumPlayers; }
double MinUtility() const override { return -kMaxScore; }
double MaxUtility() const override { return kMaxScore; }
absl::optional<double> UtilitySum() const override { return 0; }

int GetObservationTensorSize(int num_tricks) const {
int kPlayTensorSize =
kNumBidLevels // What the contract is
+ kNumDenominations // What trumps are
+ kNumOtherCalls // Undoubled / doubled / redoubled
+ kNumPlayers // Who declarer is
+ kNumVulnerabilities // Vulnerability of the declaring side
+ kNumCards // Our remaining cards
+ kNumCards // Dummy's remaining cards
+ num_tricks * kNumPlayers * kNumCards // Number of played tricks
+ kNumTricks // Number of tricks we have won
+ kNumTricks; // Number of tricks they have won
int kObservationTensorSize = kNumObservationTypes + std::max(kPlayTensorSize, kAuctionTensorSize);
return kObservationTensorSize;
}

std::vector<int> ObservationTensorShape() const override {
return {kObservationTensorSize};
return {GetObservationTensorSize(NumTricks())};
}
int MaxGameLength() const override {
return UseDoubleDummyResult() ? kMaxAuctionLength
Expand Down Expand Up @@ -259,6 +263,9 @@ class BridgeGame : public Game {
bool IsNonDealerVulnerable() const {
return ParameterValue<bool>("non_dealer_vul", false);
}
int NumTricks() const {
return ParameterValue<int>("num_tricks", 2);
}
};

} // namespace bridge
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ GameType.information = Information.IMPERFECT_INFORMATION
GameType.long_name = "Contract Bridge"
GameType.max_num_players = 4
GameType.min_num_players = 4
GameType.parameter_specification = ["dealer_vul", "non_dealer_vul", "use_double_dummy_result"]
GameType.parameter_specification = ["dealer_vul", "non_dealer_vul", "num_tricks", "use_double_dummy_result"]
GameType.provides_information_state_string = False
GameType.provides_information_state_tensor = False
GameType.provides_observation_string = True
Expand All @@ -19,7 +19,7 @@ GameType.utility = Utility.ZERO_SUM
NumDistinctActions() = 90
PolicyTensorShape() = [90]
MaxChanceOutcomes() = 52
GetParameters() = {dealer_vul=False,non_dealer_vul=False,use_double_dummy_result=False}
GetParameters() = {dealer_vul=False,non_dealer_vul=False,num_tricks=2,use_double_dummy_result=False}
NumPlayers() = 4
MinUtility() = -7600.0
MaxUtility() = 7600.0
Expand Down
4 changes: 2 additions & 2 deletions open_spiel/integration_tests/playthroughs/bridge.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ GameType.information = Information.IMPERFECT_INFORMATION
GameType.long_name = "Contract Bridge"
GameType.max_num_players = 4
GameType.min_num_players = 4
GameType.parameter_specification = ["dealer_vul", "non_dealer_vul", "use_double_dummy_result"]
GameType.parameter_specification = ["dealer_vul", "non_dealer_vul", "num_tricks", "use_double_dummy_result"]
GameType.provides_information_state_string = False
GameType.provides_information_state_tensor = False
GameType.provides_observation_string = True
Expand All @@ -19,7 +19,7 @@ GameType.utility = Utility.ZERO_SUM
NumDistinctActions() = 90
PolicyTensorShape() = [90]
MaxChanceOutcomes() = 52
GetParameters() = {dealer_vul=False,non_dealer_vul=False,use_double_dummy_result=True}
GetParameters() = {dealer_vul=False,non_dealer_vul=False,num_tricks=2,use_double_dummy_result=True}
NumPlayers() = 4
MinUtility() = -7600.0
MaxUtility() = 7600.0
Expand Down

0 comments on commit ab14e62

Please sign in to comment.