handle context cancelation
This commit is contained in:
@@ -1,23 +1,31 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"os"
|
"os"
|
||||||
|
"os/signal"
|
||||||
|
|
||||||
"git.naterciomoniz.net/applications/broker2anexoj/internal"
|
"git.naterciomoniz.net/applications/broker2anexoj/internal"
|
||||||
"git.naterciomoniz.net/applications/broker2anexoj/internal/trading212"
|
"git.naterciomoniz.net/applications/broker2anexoj/internal/trading212"
|
||||||
|
"golang.org/x/sync/errgroup"
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
err := run()
|
err := run(context.Background())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Error("found a fatal issue", slog.Any("err", err))
|
slog.Error("found a fatal issue", slog.Any("err", err))
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func run() error {
|
func run(ctx context.Context) error {
|
||||||
|
ctx, cancel := signal.NotifyContext(ctx, os.Kill, os.Interrupt)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
eg, ctx := errgroup.WithContext(ctx)
|
||||||
|
|
||||||
f, err := os.Open("test.csv")
|
f, err := os.Open("test.csv")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("open statement: %w", err)
|
return fmt.Errorf("open statement: %w", err)
|
||||||
@@ -27,7 +35,11 @@ func run() error {
|
|||||||
|
|
||||||
writer := internal.NewStdOutLogger()
|
writer := internal.NewStdOutLogger()
|
||||||
|
|
||||||
err = internal.BuildReport(reader, writer)
|
eg.Go(func() error {
|
||||||
|
return internal.BuildReport(ctx, reader, writer)
|
||||||
|
})
|
||||||
|
|
||||||
|
err = eg.Wait()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -1,6 +1,7 @@
|
|||||||
package internal
|
package internal
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
@@ -10,40 +11,45 @@ import (
|
|||||||
|
|
||||||
type RecordReader interface {
|
type RecordReader interface {
|
||||||
// ReadRecord should return Records until an error is found.
|
// ReadRecord should return Records until an error is found.
|
||||||
ReadRecord() (Record, error)
|
ReadRecord(context.Context) (Record, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type ReportWriter interface {
|
type ReportWriter interface {
|
||||||
// ReportWriter writes report items
|
// ReportWriter writes report items
|
||||||
Write(ReportItem) error
|
Write(context.Context, ReportItem) error
|
||||||
}
|
}
|
||||||
|
|
||||||
func BuildReport(reader RecordReader, writer ReportWriter) error {
|
func BuildReport(ctx context.Context, reader RecordReader, writer ReportWriter) error {
|
||||||
buys := make(map[string]*RecordQueue)
|
buys := make(map[string]*RecordQueue)
|
||||||
|
|
||||||
for {
|
for {
|
||||||
rec, err := reader.ReadRecord()
|
select {
|
||||||
if err != nil {
|
case <-ctx.Done():
|
||||||
if errors.Is(err, io.EOF) {
|
return ctx.Err()
|
||||||
return nil
|
default:
|
||||||
|
rec, err := reader.ReadRecord(ctx)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, io.EOF) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
buyQueue, ok := buys[rec.Symbol()]
|
buyQueue, ok := buys[rec.Symbol()]
|
||||||
if !ok {
|
if !ok {
|
||||||
buyQueue = new(RecordQueue)
|
buyQueue = new(RecordQueue)
|
||||||
buys[rec.Symbol()] = buyQueue
|
buys[rec.Symbol()] = buyQueue
|
||||||
}
|
}
|
||||||
|
|
||||||
err = processRecord(buyQueue, rec, writer)
|
err = processRecord(ctx, buyQueue, rec, writer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("processing record: %w", err)
|
return fmt.Errorf("processing record: %w", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func processRecord(q *RecordQueue, rec Record, writer ReportWriter) error {
|
func processRecord(ctx context.Context, q *RecordQueue, rec Record, writer ReportWriter) error {
|
||||||
switch rec.Side() {
|
switch rec.Side() {
|
||||||
case SideBuy:
|
case SideBuy:
|
||||||
q.Push(rec)
|
q.Push(rec)
|
||||||
@@ -72,7 +78,7 @@ func processRecord(q *RecordQueue, rec Record, writer ReportWriter) error {
|
|||||||
sellValue := new(big.Float).Mul(matchedQty, rec.Price())
|
sellValue := new(big.Float).Mul(matchedQty, rec.Price())
|
||||||
buyValue := new(big.Float).Mul(matchedQty, buy.Price())
|
buyValue := new(big.Float).Mul(matchedQty, buy.Price())
|
||||||
|
|
||||||
err := writer.Write(ReportItem{
|
err := writer.Write(ctx, ReportItem{
|
||||||
BuyValue: buyValue,
|
BuyValue: buyValue,
|
||||||
BuyTimestamp: buy.Timestamp(),
|
BuyTimestamp: buy.Timestamp(),
|
||||||
SellValue: sellValue,
|
SellValue: sellValue,
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package internal_test
|
package internal_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"io"
|
"io"
|
||||||
"math/big"
|
"math/big"
|
||||||
"testing"
|
"testing"
|
||||||
@@ -20,7 +21,7 @@ func TestReporter_Run(t *testing.T) {
|
|||||||
mockRecord(ctrl, 20.0, 10.0, internal.SideBuy, now),
|
mockRecord(ctrl, 20.0, 10.0, internal.SideBuy, now),
|
||||||
mockRecord(ctrl, 25.0, 10.0, internal.SideSell, now.Add(1)),
|
mockRecord(ctrl, 25.0, 10.0, internal.SideSell, now.Add(1)),
|
||||||
}
|
}
|
||||||
reader.EXPECT().ReadRecord().DoAndReturn(func() (internal.Record, error) {
|
reader.EXPECT().ReadRecord(gomock.Any()).DoAndReturn(func(ctx context.Context) (internal.Record, error) {
|
||||||
if len(records) > 0 {
|
if len(records) > 0 {
|
||||||
r := records[0]
|
r := records[0]
|
||||||
records = records[1:]
|
records = records[1:]
|
||||||
@@ -31,7 +32,7 @@ func TestReporter_Run(t *testing.T) {
|
|||||||
}).Times(3)
|
}).Times(3)
|
||||||
|
|
||||||
writer := mocks.NewMockReportWriter(ctrl)
|
writer := mocks.NewMockReportWriter(ctrl)
|
||||||
writer.EXPECT().Write(gomock.Eq(internal.ReportItem{
|
writer.EXPECT().Write(gomock.Any(), gomock.Eq(internal.ReportItem{
|
||||||
BuyValue: new(big.Float).SetFloat64(200.0),
|
BuyValue: new(big.Float).SetFloat64(200.0),
|
||||||
BuyTimestamp: now,
|
BuyTimestamp: now,
|
||||||
SellValue: new(big.Float).SetFloat64(250.0),
|
SellValue: new(big.Float).SetFloat64(250.0),
|
||||||
@@ -40,7 +41,7 @@ func TestReporter_Run(t *testing.T) {
|
|||||||
Taxes: new(big.Float),
|
Taxes: new(big.Float),
|
||||||
})).Times(1)
|
})).Times(1)
|
||||||
|
|
||||||
gotErr := internal.BuildReport(reader, writer)
|
gotErr := internal.BuildReport(t.Context(), reader, writer)
|
||||||
if gotErr != nil {
|
if gotErr != nil {
|
||||||
t.Fatalf("got unexpected err: %v", gotErr)
|
t.Fatalf("got unexpected err: %v", gotErr)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package internal
|
package internal
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"os"
|
"os"
|
||||||
@@ -24,7 +25,7 @@ func NewReportLogger(w io.Writer) *ReportLogger {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rl *ReportLogger) Write(ri ReportItem) error {
|
func (rl *ReportLogger) Write(_ context.Context, ri ReportItem) error {
|
||||||
rl.counter++
|
rl.counter++
|
||||||
_, err := fmt.Fprintf(rl.writer, "%6d - realised %+f on %s\n", rl.counter, ri.RealisedPnL(), ri.SellTimestamp.Format(time.RFC3339))
|
_, err := fmt.Fprintf(rl.writer, "%6d - realised %+f on %s\n", rl.counter, ri.RealisedPnL(), ri.SellTimestamp.Format(time.RFC3339))
|
||||||
return err
|
return err
|
||||||
|
|||||||
@@ -73,7 +73,7 @@ func TestReportLogger_Write(t *testing.T) {
|
|||||||
rw := internal.NewReportLogger(buf)
|
rw := internal.NewReportLogger(buf)
|
||||||
|
|
||||||
for _, item := range tt.items {
|
for _, item := range tt.items {
|
||||||
err := rw.Write(item)
|
err := rw.Write(t.Context(), item)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unexpected error on write: %v", err)
|
t.Fatalf("unexpected error on write: %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package trading212
|
package trading212
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/csv"
|
"encoding/csv"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
@@ -66,7 +67,7 @@ const (
|
|||||||
LimitSell = "limit sell"
|
LimitSell = "limit sell"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (rr RecordReader) ReadRecord() (internal.Record, error) {
|
func (rr RecordReader) ReadRecord(_ context.Context) (internal.Record, error) {
|
||||||
for {
|
for {
|
||||||
raw, err := rr.reader.Read()
|
raw, err := rr.reader.Read()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -93,7 +93,7 @@ func TestRecordReader_ReadRecord(t *testing.T) {
|
|||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
rr := NewRecordReader(tt.r)
|
rr := NewRecordReader(tt.r)
|
||||||
got, gotErr := rr.ReadRecord()
|
got, gotErr := rr.ReadRecord(t.Context())
|
||||||
if gotErr != nil {
|
if gotErr != nil {
|
||||||
if !tt.wantErr {
|
if !tt.wantErr {
|
||||||
t.Fatalf("ReadRecord() failed: %v", gotErr)
|
t.Fatalf("ReadRecord() failed: %v", gotErr)
|
||||||
|
|||||||
Reference in New Issue
Block a user