Skip to content

Commit

Permalink
databaseがctxを受けるようにする
Browse files Browse the repository at this point in the history
  • Loading branch information
Azuki-bar committed Nov 21, 2023
1 parent 0f6e1e4 commit 67d8832
Show file tree
Hide file tree
Showing 4 changed files with 95 additions and 90 deletions.
12 changes: 8 additions & 4 deletions backend/onetime/seed-data/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,18 @@ func main() {
if err := godotenv.Load(".env"); err != nil {
panic(err)
}
db, err := dbhandler.Open(context.TODO(), options.Client().ApplyURI(os.Getenv("MONGODB_URI")))
ctx, cancel := context.WithCancelCause(context.Background())
defer cancel(nil)
db, err := dbhandler.Open(ctx, options.Client().ApplyURI(os.Getenv("MONGODB_URI")))
if err != nil {
cancel(err)
return
}
defer db.Close()
defer db.Close(ctx)
data := &Seed{}
b, _ := os.ReadFile("./data/nt-tokyo.yaml")
if err := yaml.Unmarshal(b, data); err != nil {
panic(err)
cancel(err)
}

//for _, stop := range data.StopRails {
Expand All @@ -63,11 +66,12 @@ func main() {
//}
for _, block := range data.Blocks {
println(block)
err := db.AddBlock(&statev1.BlockState{
err := db.AddBlock(ctx, &statev1.BlockState{
BlockId: string(block),
State: statev1.BlockStateEnum_BLOCK_STATE_OPEN,
})
if err != nil {
cancel(err)
return
}
}
Expand Down
12 changes: 6 additions & 6 deletions backend/state-manager/pkg/connect/connect_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ func (s *StateManagerServer) GetBlockStates(
ctx context.Context,
req *connect.Request[statev1.GetBlockStatesRequest],
) (*connect.Response[statev1.GetBlockStatesResponse], error) {
blockStates, err := s.DBHandler.GetBlocks()
blockStates, err := s.DBHandler.GetBlocks(ctx)
if err != nil {
err = connect.NewError(
connect.CodeUnknown,
Expand Down Expand Up @@ -57,7 +57,7 @@ func (s *StateManagerServer) UpdateBlockState(
ctx context.Context,
req *connect.Request[statev1.UpdateBlockStateRequest],
) (*connect.Response[statev1.UpdateBlockStateResponse], error) {
err := s.DBHandler.UpdateBlock(req.Msg.State)
err := s.DBHandler.UpdateBlock(ctx, req.Msg.State)
if err != nil {
err = connect.NewError(
connect.CodeUnknown,
Expand All @@ -77,7 +77,7 @@ func (s *StateManagerServer) UpdatePointState(
ctx context.Context,
req *connect.Request[statev1.UpdatePointStateRequest],
) (*connect.Response[statev1.UpdatePointStateResponse], error) {
err := s.DBHandler.UpdatePoint(req.Msg.State)
err := s.DBHandler.UpdatePoint(ctx, req.Msg.State)
if err != nil {
err = connect.NewError(
connect.CodeUnknown,
Expand All @@ -86,7 +86,7 @@ func (s *StateManagerServer) UpdatePointState(
slog.Default().Error("db error", err)
return nil, err
}
s.MqttHandler.NotifyStateUpdate("point", req.Msg.State.Id, req.Msg.State.State.String())
s.MqttHandler.NotifyStateUpdate(ctx, "point", req.Msg.State.Id, req.Msg.State.State.String())

return connect.NewResponse(&statev1.UpdatePointStateResponse{}), nil
}
Expand All @@ -110,7 +110,7 @@ func (s *StateManagerServer) UpdateStopState(
ctx context.Context,
req *connect.Request[statev1.UpdateStopStateRequest],
) (*connect.Response[statev1.UpdateStopStateResponse], error) {
err := s.DBHandler.UpdateStop(req.Msg.State)
err := s.DBHandler.UpdateStop(ctx,req.Msg.State)
if err != nil {
err = connect.NewError(
connect.CodeUnknown,
Expand All @@ -119,7 +119,7 @@ func (s *StateManagerServer) UpdateStopState(
slog.Default().Error("db connection error", err)
return nil, err
}
s.MqttHandler.NotifyStateUpdate("stop", req.Msg.State.Id, req.Msg.State.State.String())
s.MqttHandler.NotifyStateUpdate(ctx, "stop", req.Msg.State.Id, req.Msg.State.State.String())
return connect.NewResponse(&statev1.UpdateStopStateResponse{}), nil
}

Expand Down
101 changes: 48 additions & 53 deletions backend/state-manager/pkg/db/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ package db
import (
"context"
"fmt"
"log"
"log/slog"

statev1 "github.com/ueckoken/plarail2023/backend/spec/state/v1"
Expand Down Expand Up @@ -40,12 +39,11 @@ func Open(ctx context.Context, opts *options.ClientOptions) (*DBHandler, error)
}, nil
}

func (db *DBHandler) Close() {
func (db *DBHandler) Close(ctx context.Context) {
slog.Default().Debug("Closing connection to DB...")
// TODO: contextを受けて、その子contextをDBクライアントに渡す
if err := db.stateManagerDB.Client().Disconnect(context.TODO()); err != nil {
if err := db.stateManagerDB.Client().Disconnect(ctx); err != nil {
slog.Default().Error("DB Connection Closing failed")
log.Println(err)
return
}
slog.Default().Debug("DB Connection is successfully closed")
}
Expand All @@ -54,148 +52,145 @@ func (db *DBHandler) Close() {
Point
*/

func (db *DBHandler) UpdatePoint(PointAndState *statev1.PointAndState) error {
func (db *DBHandler) UpdatePoint(ctx context.Context, PointAndState *statev1.PointAndState) error {
collection := db.stateManagerDB.Collection("points")
_, err := collection.UpdateOne(
context.Background(),
ctx,
bson.M{"id": PointAndState.Id},
bson.M{"$set": bson.M{"state": PointAndState.State}},
)
if err != nil {
return err
}
return nil
return fmt.Errorf("update point failed `%w`", err)
}

func (db *DBHandler) AddPoint(PointAndState *statev1.PointAndState) error {
func (db *DBHandler) AddPoint(ctx context.Context, PointAndState *statev1.PointAndState) error {
collection := db.stateManagerDB.Collection("points")
_, err := collection.InsertOne(context.Background(), PointAndState)
_, err := collection.InsertOne(ctx, PointAndState)
if err != nil {
return err
return fmt.Errorf("insert point failed `%w`", err)
}
return nil
}

func (db *DBHandler) GetPoint(pointId string) (*statev1.PointAndState, error) {
func (db *DBHandler) GetPoint(ctx context.Context, pointId string) (*statev1.PointAndState, error) {
collection := db.stateManagerDB.Collection("points")
var result *statev1.PointAndState
err := collection.FindOne(context.Background(), bson.M{"id": pointId}).Decode(&result)
err := collection.FindOne(ctx, bson.M{"id": pointId}).Decode(&result)
if err != nil {
return nil, err
return nil, fmt.Errorf("get point failed `%w`", err)
}
return result, nil
}

func (db *DBHandler) GetPoints() []*statev1.PointAndState {
func (db *DBHandler) GetPoints(ctx context.Context) ([]*statev1.PointAndState, error) {
collection := db.stateManagerDB.Collection("points")
cursor, err := collection.Find(context.Background(), bson.M{})
cursor, err := collection.Find(ctx, bson.M{})
if err != nil {
slog.Default().Warn("Get Points failed", slog.Any("err", err))
panic(err)
return nil, fmt.Errorf("get points failed `%w`", err)
}
var result []*statev1.PointAndState
if err = cursor.All(context.Background(), &result); err != nil {
panic(err)
if err := cursor.All(ctx, &result); err != nil {
slog.Default().Warn("Get Points failed", slog.Any("err", err))
return nil, fmt.Errorf("get points failed `%w`", err)
}
return result
return result, nil
}

/*
Stop
*/

func (db *DBHandler) UpdateStop(stop *statev1.StopAndState) error {
func (db *DBHandler) UpdateStop(ctx context.Context, stop *statev1.StopAndState) error {
collection := db.stateManagerDB.Collection("stops")

_, err := collection.UpdateOne(
context.Background(),
ctx,
bson.M{"id": stop.Id},
bson.M{"$set": bson.M{"state": stop.State}},
)

if err != nil {
return err
return fmt.Errorf("update stop failed `%w`", err)
}
return nil
}

func (db *DBHandler) AddStop(stop *statev1.StopAndState) error {
func (db *DBHandler) AddStop(ctx context.Context, stop *statev1.StopAndState) error {
collection := db.stateManagerDB.Collection("stops")
_, err := collection.InsertOne(context.Background(), stop)
_, err := collection.InsertOne(ctx, stop)
if err != nil {
return err
return fmt.Errorf("insert stop failed `%w`", err)
}
return nil
}

func (db *DBHandler) GetStop(stopId string) (*statev1.StopAndState, error) {
func (db *DBHandler) GetStop(ctx context.Context, stopId string) (*statev1.StopAndState, error) {
collection := db.stateManagerDB.Collection("stops")
var result *statev1.StopAndState
err := collection.FindOne(context.Background(), bson.M{"id": stopId}).Decode(&result)
err := collection.FindOne(ctx, bson.M{"id": stopId}).Decode(&result)
if err != nil {
return nil, err
return nil, fmt.Errorf("get stop failed `%w`", err)
}
return result, nil
}

func (db *DBHandler) GetStops() []*statev1.StopAndState {
func (db *DBHandler) GetStops(ctx context.Context) ([]*statev1.StopAndState, error) {
collection := db.stateManagerDB.Collection("stops")
cursor, err := collection.Find(context.Background(), bson.M{})
cursor, err := collection.Find(ctx, bson.M{})
if err != nil {
panic(err)
return nil, fmt.Errorf("get stops failed `%w`", err)
}
var result []*statev1.StopAndState
if err = cursor.All(context.Background(), &result); err != nil {
panic(err)
if err := cursor.All(ctx, &result); err != nil {
return nil, fmt.Errorf("get stops failed `%w`", err)
}
return result
return result, nil
}

/*
Block
*/

func (db *DBHandler) AddBlock(block *statev1.BlockState) error {
func (db *DBHandler) AddBlock(ctx context.Context, block *statev1.BlockState) error {
collection := db.stateManagerDB.Collection("blocks")
_, err := collection.InsertOne(context.Background(), block)
_, err := collection.InsertOne(ctx, block)
if err != nil {
return err
return fmt.Errorf("insert block failed `%w`", err)
}
return nil
}

func (db *DBHandler) UpdateBlock(block *statev1.BlockState) error {
func (db *DBHandler) UpdateBlock(ctx context.Context, block *statev1.BlockState) error {
collection := db.stateManagerDB.Collection("blocks")
_, err := collection.UpdateOne(
context.Background(),
ctx,
bson.M{"blockid": block.BlockId},
bson.M{"$set": bson.M{"state": block.State}},
)
if err != nil {
return err
return fmt.Errorf("update block failed `%w`", err)
}
return nil
}

func (db *DBHandler) GetBlock(blockId string) (*statev1.BlockState, error) {
func (db *DBHandler) GetBlock(ctx context.Context, blockId string) (*statev1.BlockState, error) {
collection := db.stateManagerDB.Collection("blocks")
var result *statev1.BlockState
err := collection.FindOne(context.Background(), bson.M{"blockid": blockId}).Decode(&result)
err := collection.FindOne(ctx, bson.M{"blockid": blockId}).Decode(&result)
if err != nil {
return nil, err
return nil, fmt.Errorf("get block failed `%w`", err)
}
return result, nil
}

func (db *DBHandler) GetBlocks() ([]*statev1.BlockState, error) {
func (db *DBHandler) GetBlocks(ctx context.Context) ([]*statev1.BlockState, error) {
collection := db.stateManagerDB.Collection("blocks")
cursor, err := collection.Find(context.Background(), bson.M{})
cursor, err := collection.Find(ctx, bson.M{})
if err != nil {
return nil, err
return nil, fmt.Errorf("get blocks failed `%w`", err)
}
var result []*statev1.BlockState
if err = cursor.All(context.Background(), &result); err != nil {
return nil, err
if err = cursor.All(ctx, &result); err != nil {
return nil, fmt.Errorf("get blocks failed `%w`", err)
}
return result, nil
}
Loading

0 comments on commit 67d8832

Please sign in to comment.