Skip to content

Commit

Permalink
Merge pull request #131230 from cockroachdb/blathers/backport-release…
Browse files Browse the repository at this point in the history
…-24.2.3-rc-131209

release-24.2.3-rc: sqlliveness: detect and handle invalid SessionIDs
  • Loading branch information
RaduBerinde authored Sep 23, 2024
2 parents b4abaa2 + a5db6e0 commit 217b43e
Show file tree
Hide file tree
Showing 15 changed files with 138 additions and 22 deletions.
7 changes: 7 additions & 0 deletions pkg/ccl/logictestccl/tests/3node-tenant/generated_test.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions pkg/sql/catalog/lease/lease_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1709,8 +1709,8 @@ func TestLeaseCountDetailSessionBased(t *testing.T) {
version := 1
region := enum.One
_, err := executor.Exec(ctx, "add-rows-for-test", nil,
fmt.Sprintf("INSERT INTO system.lease VALUES (%d, %d, %s, '%s', '\\x%x')",
descID, version, nodeID, session.ID(), region))
fmt.Sprintf("INSERT INTO system.lease VALUES (%d, %d, %s, '\\x%x', '\\x%x')",
descID, version, nodeID, session.ID().UnsafeBytes(), region))
if err != nil {
return err
}
Expand Down
36 changes: 36 additions & 0 deletions pkg/sql/logictest/testdata/logic_test/sqlliveness
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Validate that invalid sessionID's are always
# considered dead.
subtest invalid_sessions

# Legacy non-RBR format
query B
select crdb_internal.sql_liveness_is_alive(x'1f915e98f96145a5baa9f3a42c378eb6');
----
false

# Wrong length
query B
select crdb_internal.sql_liveness_is_alive(x'deadbeef');
----
false

subtest end


subtest valid_sessions

# Sanity: All sessions are alive in sqlliveness.
query I
SELECT count(*) FROM system.sqlliveness WHERE crdb_internal.sql_liveness_is_alive(session_id) = false;
----
0

query B
SELECT count(*) > 0 FROM system.sqlliveness WHERE crdb_internal.sql_liveness_is_alive(session_id) = true;
----
true

subtest end



7 changes: 7 additions & 0 deletions pkg/sql/logictest/tests/fakedist-disk/generated_test.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 7 additions & 0 deletions pkg/sql/logictest/tests/fakedist-vec-off/generated_test.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 7 additions & 0 deletions pkg/sql/logictest/tests/fakedist/generated_test.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 7 additions & 0 deletions pkg/sql/logictest/tests/local-mixed-23.2/generated_test.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 7 additions & 0 deletions pkg/sql/logictest/tests/local-vec-off/generated_test.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 7 additions & 0 deletions pkg/sql/logictest/tests/local/generated_test.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 5 additions & 0 deletions pkg/sql/schemachanger/comparator_generated_test.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 5 additions & 0 deletions pkg/sql/sqlliveness/slstorage/key_encoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
type keyCodec interface {
encode(sid sqlliveness.SessionID) (roachpb.Key, string, error)
decode(key roachpb.Key) (sqlliveness.SessionID, error)
validate(session sqlliveness.SessionID) error

// indexPrefix returns the prefix for an encoded key. encode() will return
// something with the prefix and decode will expect a key with this prefix.
Expand All @@ -37,6 +38,10 @@ type rbrEncoder struct {
rbrIndex roachpb.Key
}

func (e *rbrEncoder) validate(session sqlliveness.SessionID) error {
return ValidateSessionID(session)
}

func (e *rbrEncoder) encode(session sqlliveness.SessionID) (roachpb.Key, string, error) {
region, _, err := SafeDecodeSessionID(session)
if err != nil {
Expand Down
34 changes: 15 additions & 19 deletions pkg/sql/sqlliveness/slstorage/sessionid.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,18 +67,8 @@ func MakeSessionID(region []byte, id uuid.UUID) (sqlliveness.SessionID, error) {
// not be mutated.
func UnsafeDecodeSessionID(session sqlliveness.SessionID) (region, id []byte, err error) {
b := session.UnsafeBytes()
if len(b) == legacyLen {
return nil, nil, errors.Newf("unexpected legacy SessionID format")
}
if len(b) < minimumNonLegacyLen {
// The smallest valid v1 session id is a [version, 1, single_byte_region, uuid...],
// which is three bytes larger than a uuid.
return nil, nil, errors.New("session id is too short")
}

// Decode the version.
if b[0] != sessionIDVersion {
return nil, nil, errors.Newf("invalid session id version: %d", b[0])
if err = ValidateSessionID(sqlliveness.SessionID(b)); err != nil {
return nil, nil, err
}
regionLen := int(b[1])
rest := b[2:]
Expand All @@ -91,24 +81,30 @@ func UnsafeDecodeSessionID(session sqlliveness.SessionID) (region, id []byte, er
return rest[:regionLen], rest[regionLen:], nil
}

// SafeDecodeSessionID decodes the region and id from the SessionID.
func SafeDecodeSessionID(session sqlliveness.SessionID) (region, id string, err error) {
// ValidateSessionID validates that the SessionID has the correct format.
func ValidateSessionID(session sqlliveness.SessionID) error {
if len(session) == legacyLen {
return "", "", errors.Newf("unexpected legacy SessionID format")
return errors.Newf("unexpected legacy SessionID format")
}
if len(session) < minimumNonLegacyLen {
// The smallest valid v1 session id is a [version, 1, single_byte_region, uuid...],
// which is three bytes larger than a uuid.
return "", "", errors.New("session id is too short")
return errors.New("session id is too short")
}

// Decode the version.
if session[0] != sessionIDVersion {
return "", "", errors.Newf("invalid session id version: %d", session[0])
return errors.Newf("invalid session id version: %d", session[0])
}
return nil
}

// SafeDecodeSessionID decodes the region and id from the SessionID.
func SafeDecodeSessionID(session sqlliveness.SessionID) (region, id string, err error) {
if err = ValidateSessionID(session); err != nil {
return "", "", err
}
regionLen := int(session[1])
rest := session[2:]

// Decode and validate the length of the region.
if len(rest) != regionLen+uuid.Size {
return "", "", errors.Newf("session id with length %d is the wrong size to include a region with length %d", len(session), regionLen)
Expand Down
13 changes: 12 additions & 1 deletion pkg/sql/sqlliveness/slstorage/slstorage.go
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,15 @@ const (
func (s *Storage) isAlive(
ctx context.Context, sid sqlliveness.SessionID, syncOrAsync readType,
) (alive bool, _ error) {

// Confirm the session ID has the correct format, and if it
// doesn't then we can consider it as dead without any extra
// work.
if err := s.keyCodec.validate(sid); err != nil {
// This SessionID may be invalid because of the wrong format
// so consider it as dead.
//nolint:returnerrcheck
return false, nil
}
// If wait is false, alive is set and future is unset.
// If wait is true, alive is unset and future is set.
alive, wait, future, err := func() (bool, bool, singleflight.Future, error) {
Expand Down Expand Up @@ -318,6 +326,9 @@ func (s *Storage) deleteOrFetchSession(
ctx = multitenant.WithTenantCostControlExemption(ctx)
livenessProber := regionliveness.NewLivenessProber(s.db, s.codec, nil, s.settings)
k, regionPhysicalRep, err := s.keyCodec.encode(sid)
if err != nil {
return false, hlc.Timestamp{}, err
}
if err := s.txn(ctx, func(ctx context.Context, txn *kv.Txn) error {
// Reset captured variable in case of retry.
deleted, expiration, prevExpiration = false, hlc.Timestamp{}, hlc.Timestamp{}
Expand Down

0 comments on commit 217b43e

Please sign in to comment.