From c2cecc0393d10d5ccc1b2b9842bd4f3aa5265b24 Mon Sep 17 00:00:00 2001 From: Julien Perrochet Date: Fri, 24 May 2024 08:37:37 +0200 Subject: [PATCH] [fix] add nil checks to avoid panic when creating an OIR without a subscription (#1039) --- pkg/models/models.go | 3 +++ pkg/models/models_test.go | 57 ++++++++++++++++++++++++++++++++++++++- 2 files changed, 59 insertions(+), 1 deletion(-) diff --git a/pkg/models/models.go b/pkg/models/models.go index cc954a42d..e7551f6e0 100644 --- a/pkg/models/models.go +++ b/pkg/models/models.go @@ -41,6 +41,9 @@ const ( // PgUUID converts an ID to a pgtype.UUID. // If the ID this is called on is nil, nil will be returned func (id *ID) PgUUID() (*pgtype.UUID, error) { + if id == nil { + return nil, nil + } pgUUID := pgtype.UUID{} err := (&pgUUID).Scan(id.String()) if err != nil { diff --git a/pkg/models/models_test.go b/pkg/models/models_test.go index 1a562205c..c20b71dd3 100644 --- a/pkg/models/models_test.go +++ b/pkg/models/models_test.go @@ -1,6 +1,12 @@ package models -import "testing" +import ( + "reflect" + "testing" + + "github.com/jackc/pgx/v5/pgtype" + "github.com/stretchr/testify/assert" +) func TestIDFromString(t *testing.T) { type args struct { @@ -82,3 +88,52 @@ func TestIDFromString(t *testing.T) { }) } } + +func scanIntoUUID(t *testing.T, id ID) *pgtype.UUID { + uuid := pgtype.UUID{} + err := uuid.Scan(id.String()) + assert.Nil(t, err) + return &uuid +} + +func TestID_PgUUID(t *testing.T) { + someID := ID("00000179-e36d-40be-838b-eca6ca350000") + badID := ID("00000179-e36d-40be-838b-eca6ca35") + tests := []struct { + name string + id *ID + want *pgtype.UUID + wantErr bool + }{ + { + name: "Ok", + id: &someID, + want: scanIntoUUID(t, someID), + wantErr: false, + }, + { + name: "Nil ID", + id: nil, + want: nil, + wantErr: false, + }, + { + name: "Bad UUID", + id: &badID, + want: nil, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.id.PgUUID() + if (err != nil) != tt.wantErr { + t.Errorf("PgUUID() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("PgUUID() got = %v, want %v", got, tt.want) + } + }) + } +}