handle context cancelation

This commit is contained in:
2025-11-13 16:05:16 +00:00
parent 54fced39aa
commit f3d0f5d71a
7 changed files with 50 additions and 29 deletions

View File

@@ -1,23 +1,31 @@
package main package main
import ( import (
"context"
"fmt" "fmt"
"log/slog" "log/slog"
"os" "os"
"os/signal"
"git.naterciomoniz.net/applications/broker2anexoj/internal" "git.naterciomoniz.net/applications/broker2anexoj/internal"
"git.naterciomoniz.net/applications/broker2anexoj/internal/trading212" "git.naterciomoniz.net/applications/broker2anexoj/internal/trading212"
"golang.org/x/sync/errgroup"
) )
func main() { func main() {
err := run() err := run(context.Background())
if err != nil { if err != nil {
slog.Error("found a fatal issue", slog.Any("err", err)) slog.Error("found a fatal issue", slog.Any("err", err))
os.Exit(1) os.Exit(1)
} }
} }
func run() error { func run(ctx context.Context) error {
ctx, cancel := signal.NotifyContext(ctx, os.Kill, os.Interrupt)
defer cancel()
eg, ctx := errgroup.WithContext(ctx)
f, err := os.Open("test.csv") f, err := os.Open("test.csv")
if err != nil { if err != nil {
return fmt.Errorf("open statement: %w", err) return fmt.Errorf("open statement: %w", err)
@@ -27,7 +35,11 @@ func run() error {
writer := internal.NewStdOutLogger() writer := internal.NewStdOutLogger()
err = internal.BuildReport(reader, writer) eg.Go(func() error {
return internal.BuildReport(ctx, reader, writer)
})
err = eg.Wait()
if err != nil { if err != nil {
return err return err
} }

View File

@@ -1,6 +1,7 @@
package internal package internal
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"io" "io"
@@ -10,40 +11,45 @@ import (
type RecordReader interface { type RecordReader interface {
// ReadRecord should return Records until an error is found. // ReadRecord should return Records until an error is found.
ReadRecord() (Record, error) ReadRecord(context.Context) (Record, error)
} }
type ReportWriter interface { type ReportWriter interface {
// ReportWriter writes report items // ReportWriter writes report items
Write(ReportItem) error Write(context.Context, ReportItem) error
} }
func BuildReport(reader RecordReader, writer ReportWriter) error { func BuildReport(ctx context.Context, reader RecordReader, writer ReportWriter) error {
buys := make(map[string]*RecordQueue) buys := make(map[string]*RecordQueue)
for { for {
rec, err := reader.ReadRecord() select {
if err != nil { case <-ctx.Done():
if errors.Is(err, io.EOF) { return ctx.Err()
return nil default:
rec, err := reader.ReadRecord(ctx)
if err != nil {
if errors.Is(err, io.EOF) {
return nil
}
return err
} }
return err
}
buyQueue, ok := buys[rec.Symbol()] buyQueue, ok := buys[rec.Symbol()]
if !ok { if !ok {
buyQueue = new(RecordQueue) buyQueue = new(RecordQueue)
buys[rec.Symbol()] = buyQueue buys[rec.Symbol()] = buyQueue
} }
err = processRecord(buyQueue, rec, writer) err = processRecord(ctx, buyQueue, rec, writer)
if err != nil { if err != nil {
return fmt.Errorf("processing record: %w", err) return fmt.Errorf("processing record: %w", err)
}
} }
} }
} }
func processRecord(q *RecordQueue, rec Record, writer ReportWriter) error { func processRecord(ctx context.Context, q *RecordQueue, rec Record, writer ReportWriter) error {
switch rec.Side() { switch rec.Side() {
case SideBuy: case SideBuy:
q.Push(rec) q.Push(rec)
@@ -72,7 +78,7 @@ func processRecord(q *RecordQueue, rec Record, writer ReportWriter) error {
sellValue := new(big.Float).Mul(matchedQty, rec.Price()) sellValue := new(big.Float).Mul(matchedQty, rec.Price())
buyValue := new(big.Float).Mul(matchedQty, buy.Price()) buyValue := new(big.Float).Mul(matchedQty, buy.Price())
err := writer.Write(ReportItem{ err := writer.Write(ctx, ReportItem{
BuyValue: buyValue, BuyValue: buyValue,
BuyTimestamp: buy.Timestamp(), BuyTimestamp: buy.Timestamp(),
SellValue: sellValue, SellValue: sellValue,

View File

@@ -1,6 +1,7 @@
package internal_test package internal_test
import ( import (
"context"
"io" "io"
"math/big" "math/big"
"testing" "testing"
@@ -20,7 +21,7 @@ func TestReporter_Run(t *testing.T) {
mockRecord(ctrl, 20.0, 10.0, internal.SideBuy, now), mockRecord(ctrl, 20.0, 10.0, internal.SideBuy, now),
mockRecord(ctrl, 25.0, 10.0, internal.SideSell, now.Add(1)), mockRecord(ctrl, 25.0, 10.0, internal.SideSell, now.Add(1)),
} }
reader.EXPECT().ReadRecord().DoAndReturn(func() (internal.Record, error) { reader.EXPECT().ReadRecord(gomock.Any()).DoAndReturn(func(ctx context.Context) (internal.Record, error) {
if len(records) > 0 { if len(records) > 0 {
r := records[0] r := records[0]
records = records[1:] records = records[1:]
@@ -31,7 +32,7 @@ func TestReporter_Run(t *testing.T) {
}).Times(3) }).Times(3)
writer := mocks.NewMockReportWriter(ctrl) writer := mocks.NewMockReportWriter(ctrl)
writer.EXPECT().Write(gomock.Eq(internal.ReportItem{ writer.EXPECT().Write(gomock.Any(), gomock.Eq(internal.ReportItem{
BuyValue: new(big.Float).SetFloat64(200.0), BuyValue: new(big.Float).SetFloat64(200.0),
BuyTimestamp: now, BuyTimestamp: now,
SellValue: new(big.Float).SetFloat64(250.0), SellValue: new(big.Float).SetFloat64(250.0),
@@ -40,7 +41,7 @@ func TestReporter_Run(t *testing.T) {
Taxes: new(big.Float), Taxes: new(big.Float),
})).Times(1) })).Times(1)
gotErr := internal.BuildReport(reader, writer) gotErr := internal.BuildReport(t.Context(), reader, writer)
if gotErr != nil { if gotErr != nil {
t.Fatalf("got unexpected err: %v", gotErr) t.Fatalf("got unexpected err: %v", gotErr)
} }

View File

@@ -1,6 +1,7 @@
package internal package internal
import ( import (
"context"
"fmt" "fmt"
"io" "io"
"os" "os"
@@ -24,7 +25,7 @@ func NewReportLogger(w io.Writer) *ReportLogger {
} }
} }
func (rl *ReportLogger) Write(ri ReportItem) error { func (rl *ReportLogger) Write(_ context.Context, ri ReportItem) error {
rl.counter++ rl.counter++
_, err := fmt.Fprintf(rl.writer, "%6d - realised %+f on %s\n", rl.counter, ri.RealisedPnL(), ri.SellTimestamp.Format(time.RFC3339)) _, err := fmt.Fprintf(rl.writer, "%6d - realised %+f on %s\n", rl.counter, ri.RealisedPnL(), ri.SellTimestamp.Format(time.RFC3339))
return err return err

View File

@@ -73,7 +73,7 @@ func TestReportLogger_Write(t *testing.T) {
rw := internal.NewReportLogger(buf) rw := internal.NewReportLogger(buf)
for _, item := range tt.items { for _, item := range tt.items {
err := rw.Write(item) err := rw.Write(t.Context(), item)
if err != nil { if err != nil {
t.Fatalf("unexpected error on write: %v", err) t.Fatalf("unexpected error on write: %v", err)
} }

View File

@@ -1,6 +1,7 @@
package trading212 package trading212
import ( import (
"context"
"encoding/csv" "encoding/csv"
"fmt" "fmt"
"io" "io"
@@ -66,7 +67,7 @@ const (
LimitSell = "limit sell" LimitSell = "limit sell"
) )
func (rr RecordReader) ReadRecord() (internal.Record, error) { func (rr RecordReader) ReadRecord(_ context.Context) (internal.Record, error) {
for { for {
raw, err := rr.reader.Read() raw, err := rr.reader.Read()
if err != nil { if err != nil {

View File

@@ -93,7 +93,7 @@ func TestRecordReader_ReadRecord(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
rr := NewRecordReader(tt.r) rr := NewRecordReader(tt.r)
got, gotErr := rr.ReadRecord() got, gotErr := rr.ReadRecord(t.Context())
if gotErr != nil { if gotErr != nil {
if !tt.wantErr { if !tt.wantErr {
t.Fatalf("ReadRecord() failed: %v", gotErr) t.Fatalf("ReadRecord() failed: %v", gotErr)