Merge pull request 'Several improvements and bug fixes for 2025 tax return' (#25) from fix-bugs-for-2025-report into main
All checks were successful
Badges / coveralls (push) Successful in 14s

Reviewed-on: #25
This commit was merged in pull request #25.
This commit is contained in:
2026-05-17 09:47:28 +01:00
15 changed files with 604 additions and 194 deletions

View File

@@ -0,0 +1,57 @@
package main
import (
"encoding/csv"
"fmt"
"io"
"github.com/nmoniz/any2anexoj/internal"
)
type CSVWriter struct {
w *csv.Writer
}
func NewCSVWriter(w io.Writer) *CSVWriter {
return &CSVWriter{w: csv.NewWriter(w)}
}
func (cw *CSVWriter) Render(aw *internal.AggregatorWriter) error {
err := cw.w.Write([]string{
"source_country", "code",
"realization_year", "realization_month", "realization_day", "realization_value",
"acquisition_year", "acquisition_month", "acquisition_day", "acquisition_value",
"expenses", "foreign_tax_paid", "counter_country",
})
if err != nil {
return fmt.Errorf("write csv header: %w", err)
}
for ri := range aw.Iter() {
err := cw.w.Write(reportItemToRow(ri))
if err != nil {
return fmt.Errorf("write csv row: %w", err)
}
}
cw.w.Flush()
return cw.w.Error()
}
func reportItemToRow(ri internal.ReportItem) []string {
return []string{
fmt.Sprintf("%d", ri.AssetCountry),
string(ri.Nature),
fmt.Sprintf("%d", ri.SellTimestamp.Year()),
fmt.Sprintf("%d", int(ri.SellTimestamp.Month())),
fmt.Sprintf("%d", ri.SellTimestamp.Day()),
ri.SellValue.StringFixed(2),
fmt.Sprintf("%d", ri.BuyTimestamp.Year()),
fmt.Sprintf("%d", int(ri.BuyTimestamp.Month())),
fmt.Sprintf("%d", ri.BuyTimestamp.Day()),
ri.BuyValue.StringFixed(2),
ri.Fees.StringFixed(2),
ri.Taxes.StringFixed(2),
fmt.Sprintf("%d", ri.BrokerCountry),
}
}

View File

@@ -21,14 +21,11 @@ var (
// remove/change default // remove/change default
platform = pflag.StringP("platform", "p", "trading212", "One of the supported platforms") platform = pflag.StringP("platform", "p", "trading212", "One of the supported platforms")
lang = pflag.StringP("language", "l", language.Portuguese.String(), "The 2 letter language code") lang = pflag.StringP("language", "l", language.Portuguese.String(), "The 2 letter language code")
debug = pflag.BoolP("debug", "d", false, "Activate to log debug messages")
format = pflag.StringP("format", "f", "table", "Output format: table or csv")
ofAPIKey = pflag.String("open-figi-api-key", "", "An OpenFIGI API key for faster report generation (better rate api rate limits)")
// TODO: improve documentation on selectors // TODO: improve documentation on selectors
selectors = pflag.StringSlice("selectors", nil, "Only process entries that conform to all the selectors:") selectors = pflag.StringSlice("selectors", nil, "Only process entries that conform to all the selectors:")
readerFactories = map[string]func() internal.RecordReader{
"trading212": func() internal.RecordReader {
return trading212.NewRecordReader(os.Stdin, internal.NewOpenFIGI(&http.Client{Timeout: 5 * time.Second}))
},
}
) )
func main() { func main() {
@@ -42,6 +39,17 @@ func main() {
} }
func run(ctx context.Context) error { func run(ctx context.Context) error {
ctx, cancel := signal.NotifyContext(ctx, os.Kill, os.Interrupt)
defer cancel()
eg, ctx := errgroup.WithContext(ctx)
logLevel := slog.LevelInfo
if *debug {
logLevel = slog.LevelDebug
}
slog.SetDefault(slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: logLevel})))
if platform == nil || len(*platform) == 0 { if platform == nil || len(*platform) == 0 {
slog.Error("--platform flag is required") slog.Error("--platform flag is required")
os.Exit(1) os.Exit(1)
@@ -52,20 +60,11 @@ func run(ctx context.Context) error {
os.Exit(1) os.Exit(1)
} }
ctx, cancel := signal.NotifyContext(ctx, os.Kill, os.Interrupt) reader, err := getReader(*platform, *ofAPIKey)
defer cancel() if err != nil {
return fmt.Errorf("getting reader: %w", err)
eg, ctx := errgroup.WithContext(ctx)
slog.SetDefault(slog.New(slog.NewTextHandler(os.Stderr, nil)))
factory, ok := readerFactories[*platform]
if !ok {
return fmt.Errorf("unsupported platform: %s", *platform)
} }
reader := factory()
writer := internal.NewAggregatorWriter() writer := internal.NewAggregatorWriter()
selector, err := internal.ParseSelectors(*selectors) selector, err := internal.ParseSelectors(*selectors)
@@ -82,14 +81,26 @@ func run(ctx context.Context) error {
return err return err
} }
switch *format {
case "csv":
return NewCSVWriter(os.Stdout).Render(writer)
case "table":
loc, err := NewLocalizer(*lang) loc, err := NewLocalizer(*lang)
if err != nil { if err != nil {
return fmt.Errorf("create localizer: %w", err) return fmt.Errorf("create localizer: %w", err)
} }
NewPrettyPrinter(os.Stdout, loc).Render(writer)
printer := NewPrettyPrinter(os.Stdout, loc)
printer.Render(writer)
return nil return nil
default:
return fmt.Errorf("unsupported format %q: must be table or csv", *format)
}
}
func getReader(platform string, ofAPIKey string) (internal.RecordReader, error) {
switch platform {
case "trading212":
return trading212.NewRecordReader(os.Stdin, internal.NewOpenFIGI(&http.Client{Timeout: 5 * time.Second}, ofAPIKey)), nil
default:
return nil, fmt.Errorf("unsupported platform: %s", platform)
}
} }

View File

@@ -10,18 +10,26 @@ type Filler struct {
Record Record
filled decimal.Decimal filled decimal.Decimal
quantity decimal.Decimal
price decimal.Decimal
} }
func NewFiller(r Record) *Filler { func NewFiller(r Record) *Filler {
return &Filler{ return &Filler{
Record: r, Record: r,
quantity: r.Quantity(),
price: r.Price(),
} }
} }
func (f *Filler) Quantity() decimal.Decimal { return f.quantity }
func (f *Filler) Price() decimal.Decimal { return f.price }
// Fill accrues some quantity. Returns how mutch was accrued in the 1st return value and whether // Fill accrues some quantity. Returns how mutch was accrued in the 1st return value and whether
// it was filled or not on the 2nd return value. // it was filled or not on the 2nd return value.
func (f *Filler) Fill(quantity decimal.Decimal) (decimal.Decimal, bool) { func (f *Filler) Fill(quantity decimal.Decimal) (decimal.Decimal, bool) {
unfilled := f.Record.Quantity().Sub(f.filled) unfilled := f.quantity.Sub(f.filled)
delta := decimal.Min(unfilled, quantity) delta := decimal.Min(unfilled, quantity)
f.filled = f.filled.Add(delta) f.filled = f.filled.Add(delta)
return delta, f.IsFilled() return delta, f.IsFilled()
@@ -29,7 +37,15 @@ func (f *Filler) Fill(quantity decimal.Decimal) (decimal.Decimal, bool) {
// IsFilled returns true if the fill is equal to the record quantity. // IsFilled returns true if the fill is equal to the record quantity.
func (f *Filler) IsFilled() bool { func (f *Filler) IsFilled() bool {
return f.filled.Equal(f.Quantity()) return f.filled.Equal(f.quantity)
}
// ApplySplit adjusts the lot for a stock split by the given ratio (newQty/oldQty).
// The total cost basis is preserved: quantity scales up, price scales down proportionally.
func (f *Filler) ApplySplit(ratio decimal.Decimal) {
f.quantity = f.quantity.Mul(ratio)
f.filled = f.filled.Mul(ratio)
f.price = f.price.Div(ratio)
} }
type FillerQueue struct { type FillerQueue struct {
@@ -86,6 +102,16 @@ func (fq *FillerQueue) frontElement() *list.Element {
return fq.l.Front() return fq.l.Front()
} }
// AdjustForSplit applies a stock split ratio to all lots in the queue.
func (fq *FillerQueue) AdjustForSplit(ratio decimal.Decimal) {
if fq == nil || fq.l == nil {
return
}
for e := fq.l.Front(); e != nil; e = e.Next() {
e.Value.(*Filler).ApplySplit(ratio)
}
}
// Len returns how many elements are currently on the queue // Len returns how many elements are currently on the queue
func (fq *FillerQueue) Len() int { func (fq *FillerQueue) Len() int {
if fq == nil || fq.l == nil { if fq == nil || fq.l == nil {

View File

@@ -114,7 +114,7 @@ func TestFillerQueueNilReceiver(t *testing.T) {
t.Fatalf(`want panic message %q but got "%v"`, expMsg, r) t.Fatalf(`want panic message %q but got "%v"`, expMsg, r)
} }
}() }()
rq.Push(NewFiller(nil)) rq.Push(NewFiller(&testRecord{}))
} }
type testRecord struct { type testRecord struct {
@@ -122,11 +122,11 @@ type testRecord struct {
id int id int
quantity decimal.Decimal quantity decimal.Decimal
price decimal.Decimal
} }
func (tr testRecord) Quantity() decimal.Decimal { func (tr testRecord) Quantity() decimal.Decimal { return tr.quantity }
return tr.quantity func (tr testRecord) Price() decimal.Decimal { return tr.price }
}
func TestFiller_Fill(t *testing.T) { func TestFiller_Fill(t *testing.T) {
tests := []struct { tests := []struct {
@@ -185,3 +185,89 @@ func TestFiller_Fill(t *testing.T) {
}) })
} }
} }
func TestFiller_ApplySplit(t *testing.T) {
tests := []struct {
name string
qty float64
price float64
prefilled float64
ratio float64
wantQty float64
wantPrice float64
wantFilled float64
wantCostBasis float64
}{
{
name: "5:1 split on unfilled lot preserves cost basis",
qty: 10, price: 100, prefilled: 0, ratio: 5,
wantQty: 50, wantPrice: 20, wantFilled: 0, wantCostBasis: 1000,
},
{
name: "5:1 split on partially filled lot",
qty: 10, price: 100, prefilled: 4, ratio: 5,
wantQty: 50, wantPrice: 20, wantFilled: 20, wantCostBasis: 1000,
},
{
name: "1:2 reverse split on unfilled lot preserves cost basis",
qty: 10, price: 100, prefilled: 0, ratio: 0.5,
wantQty: 5, wantPrice: 200, wantFilled: 0, wantCostBasis: 1000,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
f := NewFiller(&testRecord{
quantity: decimal.NewFromFloat(tt.qty),
price: decimal.NewFromFloat(tt.price),
})
if tt.prefilled > 0 {
f.Fill(decimal.NewFromFloat(tt.prefilled))
}
f.ApplySplit(decimal.NewFromFloat(tt.ratio))
if !f.Quantity().Equal(decimal.NewFromFloat(tt.wantQty)) {
t.Errorf("want quantity %v but got %v", tt.wantQty, f.Quantity())
}
if !f.Price().Equal(decimal.NewFromFloat(tt.wantPrice)) {
t.Errorf("want price %v but got %v", tt.wantPrice, f.Price())
}
if !f.filled.Equal(decimal.NewFromFloat(tt.wantFilled)) {
t.Errorf("want filled %v but got %v", tt.wantFilled, f.filled)
}
costBasis := f.Quantity().Mul(f.Price())
if !costBasis.Equal(decimal.NewFromFloat(tt.wantCostBasis)) {
t.Errorf("want cost basis %v but got %v", tt.wantCostBasis, costBasis)
}
})
}
}
func TestFillerQueue_AdjustForSplit(t *testing.T) {
var fq FillerQueue
fq.Push(NewFiller(&testRecord{quantity: decimal.NewFromFloat(10), price: decimal.NewFromFloat(100)}))
fq.Push(NewFiller(&testRecord{quantity: decimal.NewFromFloat(5), price: decimal.NewFromFloat(200)}))
fq.AdjustForSplit(decimal.NewFromFloat(5))
lot1, _ := fq.Pop()
if !lot1.Quantity().Equal(decimal.NewFromFloat(50)) {
t.Errorf("lot1: want quantity 50 but got %v", lot1.Quantity())
}
if !lot1.Price().Equal(decimal.NewFromFloat(20)) {
t.Errorf("lot1: want price 20 but got %v", lot1.Price())
}
lot2, _ := fq.Pop()
if !lot2.Quantity().Equal(decimal.NewFromFloat(25)) {
t.Errorf("lot2: want quantity 25 but got %v", lot2.Quantity())
}
if !lot2.Price().Equal(decimal.NewFromFloat(40)) {
t.Errorf("lot2: want price 40 but got %v", lot2.Price())
}
}
func TestFillerQueue_AdjustForSplit_NilReceiver(t *testing.T) {
var fq *FillerQueue
fq.AdjustForSplit(decimal.NewFromFloat(5)) // must not panic
}

30
internal/kind.go Normal file
View File

@@ -0,0 +1,30 @@
package internal
type Kind uint
const (
KindUnknown Kind = iota
KindBuy
KindSell
KindSplit
)
// String returns a human readable value
func (d Kind) String() string {
switch d {
case KindBuy:
return "buy"
case KindSell:
return "sell"
case KindSplit:
return "split"
default:
return "unknown"
}
}
// Is returns true when k equals o
func (k Kind) Is(o any) bool {
other, ok := o.(Kind)
return ok && k == other
}

View File

@@ -5,12 +5,12 @@ import "testing"
func TestSide_String(t *testing.T) { func TestSide_String(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
side Side side Kind
want string want string
}{ }{
{"buy", SideBuy, "buy"}, {"buy", KindBuy, "buy"},
{"sell", SideSell, "sell"}, {"sell", KindSell, "sell"},
{"unknown", SideUnknown, "unknown"}, {"unknown", KindUnknown, "unknown"},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
@@ -24,16 +24,16 @@ func TestSide_String(t *testing.T) {
func TestSide_IsBuy(t *testing.T) { func TestSide_IsBuy(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
side Side side Kind
want bool want bool
}{ }{
{"buy", SideBuy, true}, {"buy", KindBuy, true},
{"sell", SideSell, false}, {"sell", KindSell, false},
{"unknown", SideUnknown, false}, {"unknown", KindUnknown, false},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
if got := tt.side.IsBuy(); got != tt.want { if got := tt.side.Is(KindBuy); got != tt.want {
t.Errorf("want Side.IsBuy() to be %v but got %v", tt.want, got) t.Errorf("want Side.IsBuy() to be %v but got %v", tt.want, got)
} }
}) })
@@ -43,16 +43,16 @@ func TestSide_IsBuy(t *testing.T) {
func TestSide_IsSell(t *testing.T) { func TestSide_IsSell(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
side Side side Kind
want bool want bool
}{ }{
{"buy", SideBuy, false}, {"buy", KindBuy, false},
{"sell", SideSell, true}, {"sell", KindSell, true},
{"unknown", SideUnknown, false}, {"unknown", KindUnknown, false},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
if got := tt.side.IsSell(); got != tt.want { if got := tt.side.Is(KindSell); got != tt.want {
t.Errorf("want Side.IsSell() to be %v but got %v", tt.want, got) t.Errorf("want Side.IsSell() to be %v but got %v", tt.want, got)
} }
}) })

View File

@@ -220,6 +220,44 @@ func (c *MockRecordFeesCall) DoAndReturn(f func() decimal.Decimal) *MockRecordFe
return c return c
} }
// Kind mocks base method.
func (m *MockRecord) Kind() internal.Kind {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Kind")
ret0, _ := ret[0].(internal.Kind)
return ret0
}
// Kind indicates an expected call of Kind.
func (mr *MockRecordMockRecorder) Kind() *MockRecordKindCall {
mr.mock.ctrl.T.Helper()
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Kind", reflect.TypeOf((*MockRecord)(nil).Kind))
return &MockRecordKindCall{Call: call}
}
// MockRecordKindCall wrap *gomock.Call
type MockRecordKindCall struct {
*gomock.Call
}
// Return rewrite *gomock.Call.Return
func (c *MockRecordKindCall) Return(arg0 internal.Kind) *MockRecordKindCall {
c.Call = c.Call.Return(arg0)
return c
}
// Do rewrite *gomock.Call.Do
func (c *MockRecordKindCall) Do(f func() internal.Kind) *MockRecordKindCall {
c.Call = c.Call.Do(f)
return c
}
// DoAndReturn rewrite *gomock.Call.DoAndReturn
func (c *MockRecordKindCall) DoAndReturn(f func() internal.Kind) *MockRecordKindCall {
c.Call = c.Call.DoAndReturn(f)
return c
}
// Nature mocks base method. // Nature mocks base method.
func (m *MockRecord) Nature() internal.Nature { func (m *MockRecord) Nature() internal.Nature {
m.ctrl.T.Helper() m.ctrl.T.Helper()
@@ -334,44 +372,6 @@ func (c *MockRecordQuantityCall) DoAndReturn(f func() decimal.Decimal) *MockReco
return c return c
} }
// Side mocks base method.
func (m *MockRecord) Side() internal.Side {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Side")
ret0, _ := ret[0].(internal.Side)
return ret0
}
// Side indicates an expected call of Side.
func (mr *MockRecordMockRecorder) Side() *MockRecordSideCall {
mr.mock.ctrl.T.Helper()
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Side", reflect.TypeOf((*MockRecord)(nil).Side))
return &MockRecordSideCall{Call: call}
}
// MockRecordSideCall wrap *gomock.Call
type MockRecordSideCall struct {
*gomock.Call
}
// Return rewrite *gomock.Call.Return
func (c *MockRecordSideCall) Return(arg0 internal.Side) *MockRecordSideCall {
c.Call = c.Call.Return(arg0)
return c
}
// Do rewrite *gomock.Call.Do
func (c *MockRecordSideCall) Do(f func() internal.Side) *MockRecordSideCall {
c.Call = c.Call.Do(f)
return c
}
// DoAndReturn rewrite *gomock.Call.DoAndReturn
func (c *MockRecordSideCall) DoAndReturn(f func() internal.Side) *MockRecordSideCall {
c.Call = c.Call.DoAndReturn(f)
return c
}
// Symbol mocks base method. // Symbol mocks base method.
func (m *MockRecord) Symbol() string { func (m *MockRecord) Symbol() string {
m.ctrl.T.Helper() m.ctrl.T.Helper()

View File

@@ -5,6 +5,7 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"log/slog"
"net/http" "net/http"
"sync" "sync"
"time" "time"
@@ -13,9 +14,12 @@ import (
"golang.org/x/time/rate" "golang.org/x/time/rate"
) )
// OpenFIGI is a small adapter for the openfigi.com api var OpenFIGIAPIKeyHeader = http.CanonicalHeaderKey("X-OPENFIGI-APIKEY")
// OpenFIGI is a small adapter for the openfigi.com api.
type OpenFIGI struct { type OpenFIGI struct {
client *http.Client client *http.Client
apiKey string
mappingLimiter *rate.Limiter mappingLimiter *rate.Limiter
mu sync.RWMutex mu sync.RWMutex
@@ -25,11 +29,21 @@ type OpenFIGI struct {
securityTypeCache map[string]string securityTypeCache map[string]string
} }
func NewOpenFIGI(c *http.Client) *OpenFIGI { // NewOpenFIGI creates an OpenFIGI client that uses the API key if provided
func NewOpenFIGI(c *http.Client, apiKey string) *OpenFIGI {
// Rate limits as per https://www.openfigi.com/api/documentation#rate-limits
limiter := rate.NewLimiter(rate.Every(time.Minute), 25)
if len(apiKey) > 0 {
slog.Debug("OpenFIGI client: created with API Key rate limits")
limiter = rate.NewLimiter(rate.Every(time.Second*6), 25)
} else {
slog.Debug("OpenFIGI client: created with puplic rate limits")
}
return &OpenFIGI{ return &OpenFIGI{
client: c, client: c,
mappingLimiter: rate.NewLimiter(rate.Every(time.Minute), 25), // https://www.openfigi.com/api/documentation#rate-limits apiKey: apiKey,
mappingLimiter: limiter,
securityTypeCache: make(map[string]string), securityTypeCache: make(map[string]string),
} }
} }
@@ -38,10 +52,16 @@ func (of *OpenFIGI) SecurityTypeByISIN(ctx context.Context, isin string) (string
of.mu.RLock() of.mu.RLock()
if secType, ok := of.securityTypeCache[isin]; ok { if secType, ok := of.securityTypeCache[isin]; ok {
of.mu.RUnlock() of.mu.RUnlock()
slog.Debug("OpenFIGI client: SecurityTypeByISIN cache hit",
slog.String("isin", isin),
slog.String("security_type", secType))
return secType, nil return secType, nil
} }
of.mu.RUnlock() of.mu.RUnlock()
slog.Debug("OpenFIGI client: SecurityTypeByISIN cache miss",
slog.String("isin", isin))
of.mu.Lock() of.mu.Lock()
defer of.mu.Unlock() defer of.mu.Unlock()
@@ -71,6 +91,14 @@ func (of *OpenFIGI) SecurityTypeByISIN(ctx context.Context, isin string) (string
req.Header.Add("Content-Type", "application/json") req.Header.Add("Content-Type", "application/json")
if len(of.apiKey) > 0 {
req.Header.Add(OpenFIGIAPIKeyHeader, of.apiKey)
}
if !of.mappingLimiter.Allow() {
slog.Debug("OpenFIGI client: mapping limiter waiting for rate limiter capacity")
}
err = of.mappingLimiter.Wait(ctx) err = of.mappingLimiter.Wait(ctx)
if err != nil { if err != nil {
return "", fmt.Errorf("wait for mapping request capacity: %w", err) return "", fmt.Errorf("wait for mapping request capacity: %w", err)
@@ -109,6 +137,10 @@ func (of *OpenFIGI) SecurityTypeByISIN(ctx context.Context, isin string) (string
of.securityTypeCache[isin] = secType of.securityTypeCache[isin] = secType
slog.Debug("OpenFIGI client: SecurityTypeByISIN cached mapping",
slog.String("isin", isin),
slog.String("security_type", secType))
return secType, nil return secType, nil
} }

View File

@@ -110,7 +110,7 @@ func TestOpenFIGI_SecurityTypeByISIN(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
of := internal.NewOpenFIGI(tt.client) of := internal.NewOpenFIGI(tt.client, "")
got, gotErr := of.SecurityTypeByISIN(context.Background(), tt.isin) got, gotErr := of.SecurityTypeByISIN(context.Background(), tt.isin)
if gotErr != nil { if gotErr != nil {
@@ -145,7 +145,7 @@ func TestOpenFIGI_SecurityTypeByISIN_Cache(t *testing.T) {
}, nil }, nil
}) })
of := internal.NewOpenFIGI(c) of := internal.NewOpenFIGI(c, "")
got, gotErr := of.SecurityTypeByISIN(t.Context(), "NL0000235190") got, gotErr := of.SecurityTypeByISIN(t.Context(), "NL0000235190")
if gotErr != nil { if gotErr != nil {
@@ -166,6 +166,55 @@ func TestOpenFIGI_SecurityTypeByISIN_Cache(t *testing.T) {
} }
} }
func TestOpenFIGI_SecurityTypeByISIN_APIKey(t *testing.T) {
t.Run("with API key", func(t *testing.T) {
wantAPIKey := "123abc-456xyz"
c := NewTestClient(t, func(req *http.Request) (*http.Response, error) {
value, ok := req.Header[internal.OpenFIGIAPIKeyHeader]
if !ok {
t.Fatalf("want %q header but got none: %v", internal.OpenFIGIAPIKeyHeader, req.Header)
}
if len(value) != 1 {
t.Fatalf("want exactly one %q header value but got %d", internal.OpenFIGIAPIKeyHeader, len(value))
}
if value[0] != wantAPIKey {
t.Fatalf("want %q header value %q but got %q", internal.OpenFIGIAPIKeyHeader, wantAPIKey, value[0])
}
return &http.Response{
Status: http.StatusText(http.StatusOK),
StatusCode: http.StatusOK,
Body: io.NopCloser(bytes.NewBufferString(`[{"data":[{"securityType":"Common Stock"}]}]`)),
}, nil
})
of := internal.NewOpenFIGI(c, wantAPIKey)
_, err := of.SecurityTypeByISIN(t.Context(), "US1234567890")
if err != nil {
t.Fatalf("want success but got an error: %s", err)
}
})
t.Run("without API key", func(t *testing.T) {
c := NewTestClient(t, func(req *http.Request) (*http.Response, error) {
_, ok := req.Header[internal.OpenFIGIAPIKeyHeader]
if ok {
t.Fatalf("want no %s header but got one", internal.OpenFIGIAPIKeyHeader)
}
return &http.Response{
Status: http.StatusText(http.StatusOK),
StatusCode: http.StatusOK,
Body: io.NopCloser(bytes.NewBufferString(`[{"data":[{"securityType":"Common Stock"}]}]`)),
}, nil
})
of := internal.NewOpenFIGI(c, "")
_, err := of.SecurityTypeByISIN(t.Context(), "US1234567890")
if err != nil {
t.Fatalf("want success but got an error: %s", err)
}
})
}
type RoundTripFunc func(req *http.Request) (*http.Response, error) type RoundTripFunc func(req *http.Request) (*http.Response, error)
func (f RoundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { func (f RoundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {

View File

@@ -5,6 +5,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"log/slog"
"time" "time"
"github.com/shopspring/decimal" "github.com/shopspring/decimal"
@@ -15,7 +16,7 @@ type Record interface {
Nature() Nature Nature() Nature
BrokerCountry() int64 BrokerCountry() int64
AssetCountry() int64 AssetCountry() int64
Side() Side Kind() Kind
Price() decimal.Decimal Price() decimal.Decimal
Quantity() decimal.Decimal Quantity() decimal.Decimal
Timestamp() time.Time Timestamp() time.Time
@@ -54,14 +55,25 @@ type ReportWriter interface {
type Selector func(Record) bool type Selector func(Record) bool
// BuildReport reads records from a RecordReader and, if the record passes the Selector, it is // BuildReport reads records from a RecordReader and, if the record passes the Selector, it is
// processed into the report // processed into the ReportWriter.
func BuildReport(ctx context.Context, reader RecordReader, writer ReportWriter, s Selector) error { func BuildReport(ctx context.Context, reader RecordReader, writer ReportWriter, sel Selector) error {
buys := make(map[string]*FillerQueue) buys := make(map[string]*FillerQueue)
var buysCount, sellsCount int64
var lastTimestamp time.Time
progTicker := time.NewTicker(10 * time.Second)
for { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
return ctx.Err() return ctx.Err()
case <-progTicker.C:
slog.InfoContext(ctx, "Progress update",
slog.Int64("total_records", buysCount+sellsCount),
slog.Int64("sell_records", sellsCount),
slog.Int64("buy_records", buysCount),
slog.Time("last_record_timestamp", lastTimestamp),
)
default: default:
rec, err := reader.ReadRecord(ctx) rec, err := reader.ReadRecord(ctx)
if err != nil { if err != nil {
@@ -71,30 +83,52 @@ func BuildReport(ctx context.Context, reader RecordReader, writer ReportWriter,
return err return err
} }
if !s(rec) { if rec.Kind().Is(KindBuy) {
continue buysCount++
} else if rec.Kind().Is(KindSell) {
sellsCount++
} }
lastTimestamp = rec.Timestamp()
buyQueue, ok := buys[rec.Symbol()] buyQueue, ok := buys[rec.Symbol()]
if !ok { if !ok {
buyQueue = new(FillerQueue) buyQueue = new(FillerQueue)
buys[rec.Symbol()] = buyQueue buys[rec.Symbol()] = buyQueue
} }
err = processRecord(ctx, buyQueue, rec, writer) err = processRecord(ctx, buyQueue, rec, sel, writer)
if err != nil { if err != nil {
return fmt.Errorf("processing record: %w", err) return fmt.Errorf("processing record: %w", err)
} }
} }
} }
} }
func processRecord(ctx context.Context, q *FillerQueue, rec Record, writer ReportWriter) error { // processRecord either adds buys to the queue or consumes buys from the queue when processing a
switch rec.Side() { // sell record.
case SideBuy: // Selectors are only applied on sells for performance reasons. It's much cheaper to just accumulate
// buys and only actually inspect a record once a sell happens due to potential network requests to
func processRecord(ctx context.Context, q *FillerQueue, rec Record, sel Selector, writer ReportWriter) error {
slog.Debug("Report: processing record",
slog.String("symbol", rec.Symbol()),
slog.String("side", rec.Kind().String()),
)
switch rec.Kind() {
case KindBuy:
q.Push(NewFiller(rec)) q.Push(NewFiller(rec))
case SideSell: case KindSell:
if !sel(rec) {
slog.Debug("Report: skipping record",
slog.String("symbol", rec.Symbol()),
slog.String("side", rec.Kind().String()),
)
return nil
}
unmatchedQty := rec.Quantity() unmatchedQty := rec.Quantity()
for unmatchedQty.IsPositive() { for unmatchedQty.IsPositive() {
@@ -134,8 +168,11 @@ func processRecord(ctx context.Context, q *FillerQueue, rec Record, writer Repor
} }
} }
case KindSplit:
q.AdjustForSplit(rec.Quantity())
default: default:
return fmt.Errorf("unknown side: %v", rec.Side()) return fmt.Errorf("unknown side: %v", rec.Kind())
} }
return nil return nil

View File

@@ -20,8 +20,8 @@ func TestBuildReport(t *testing.T) {
reader := mocks.NewMockRecordReader(ctrl) reader := mocks.NewMockRecordReader(ctrl)
records := []internal.Record{ records := []internal.Record{
mockRecord(ctrl, 20.0, 10.0, internal.SideBuy, now), mockRecord(ctrl, 20.0, 10.0, internal.KindBuy, now),
mockRecord(ctrl, 25.0, 10.0, internal.SideSell, now.Add(1)), mockRecord(ctrl, 25.0, 10.0, internal.KindSell, now.Add(1)),
} }
reader.EXPECT().ReadRecord(gomock.Any()).DoAndReturn(func(ctx context.Context) (internal.Record, error) { reader.EXPECT().ReadRecord(gomock.Any()).DoAndReturn(func(ctx context.Context) (internal.Record, error) {
if len(records) > 0 { if len(records) > 0 {
@@ -49,14 +49,14 @@ func TestBuildReport(t *testing.T) {
} }
} }
func mockRecord(ctrl *gomock.Controller, price, quantity float64, side internal.Side, ts time.Time) *mocks.MockRecord { func mockRecord(ctrl *gomock.Controller, price, quantity float64, kind internal.Kind, ts time.Time) *mocks.MockRecord {
rec := mocks.NewMockRecord(ctrl) rec := mocks.NewMockRecord(ctrl)
rec.EXPECT().Symbol().Return("TEST").AnyTimes() rec.EXPECT().Symbol().Return("TEST").AnyTimes()
rec.EXPECT().BrokerCountry().Return(int64(countries.PT)).AnyTimes() rec.EXPECT().BrokerCountry().Return(int64(countries.PT)).AnyTimes()
rec.EXPECT().AssetCountry().Return(int64(countries.USA)).AnyTimes() rec.EXPECT().AssetCountry().Return(int64(countries.USA)).AnyTimes()
rec.EXPECT().Price().Return(decimal.NewFromFloat(price)).AnyTimes() rec.EXPECT().Price().Return(decimal.NewFromFloat(price)).AnyTimes()
rec.EXPECT().Quantity().Return(decimal.NewFromFloat(quantity)).AnyTimes() rec.EXPECT().Quantity().Return(decimal.NewFromFloat(quantity)).AnyTimes()
rec.EXPECT().Side().Return(side).AnyTimes() rec.EXPECT().Kind().Return(kind).AnyTimes()
rec.EXPECT().Timestamp().Return(ts).AnyTimes() rec.EXPECT().Timestamp().Return(ts).AnyTimes()
rec.EXPECT().Fees().Return(decimal.Decimal{}).AnyTimes() rec.EXPECT().Fees().Return(decimal.Decimal{}).AnyTimes()
rec.EXPECT().Taxes().Return(decimal.Decimal{}).AnyTimes() rec.EXPECT().Taxes().Return(decimal.Decimal{}).AnyTimes()

View File

@@ -13,7 +13,7 @@ type testRecord struct {
nature internal.Nature nature internal.Nature
brokerCountry int64 brokerCountry int64
assetCountry int64 assetCountry int64
side internal.Side side internal.Kind
price decimal.Decimal price decimal.Decimal
quantity decimal.Decimal quantity decimal.Decimal
timestamp time.Time timestamp time.Time
@@ -25,7 +25,7 @@ func (m testRecord) Symbol() string { return m.symbol }
func (m testRecord) Nature() internal.Nature { return m.nature } func (m testRecord) Nature() internal.Nature { return m.nature }
func (m testRecord) BrokerCountry() int64 { return m.brokerCountry } func (m testRecord) BrokerCountry() int64 { return m.brokerCountry }
func (m testRecord) AssetCountry() int64 { return m.assetCountry } func (m testRecord) AssetCountry() int64 { return m.assetCountry }
func (m testRecord) Side() internal.Side { return m.side } func (m testRecord) Kind() internal.Kind { return m.side }
func (m testRecord) Price() decimal.Decimal { return m.price } func (m testRecord) Price() decimal.Decimal { return m.price }
func (m testRecord) Quantity() decimal.Decimal { return m.quantity } func (m testRecord) Quantity() decimal.Decimal { return m.quantity }
func (m testRecord) Timestamp() time.Time { return m.timestamp } func (m testRecord) Timestamp() time.Time { return m.timestamp }

View File

@@ -1,30 +0,0 @@
package internal
type Side uint
const (
SideUnknown Side = iota
SideBuy
SideSell
)
func (d Side) String() string {
switch d {
case SideBuy:
return "buy"
case SideSell:
return "sell"
default:
return "unknown"
}
}
// IsBuy returns true if the s == SideBuy
func (d Side) IsBuy() bool {
return d == SideBuy
}
// IsSell returns true if the s == SideSell
func (d Side) IsSell() bool {
return d == SideSell
}

View File

@@ -18,7 +18,7 @@ import (
type Record struct { type Record struct {
symbol string symbol string
timestamp time.Time timestamp time.Time
side internal.Side kind internal.Kind
quantity decimal.Decimal quantity decimal.Decimal
price decimal.Decimal price decimal.Decimal
fees decimal.Decimal fees decimal.Decimal
@@ -44,8 +44,8 @@ func (r Record) AssetCountry() int64 {
return int64(countries.ByName(r.Symbol()[:2]).Info().Code) return int64(countries.ByName(r.Symbol()[:2]).Info().Code)
} }
func (r Record) Side() internal.Side { func (r Record) Kind() internal.Kind {
return r.side return r.kind
} }
func (r Record) Quantity() decimal.Decimal { func (r Record) Quantity() decimal.Decimal {
@@ -85,25 +85,22 @@ const (
MarketSell = "market sell" MarketSell = "market sell"
LimitBuy = "limit buy" LimitBuy = "limit buy"
LimitSell = "limit sell" LimitSell = "limit sell"
StockSplitOpen = "stock split open"
StockSplitClose = "stock split close"
StokDistribution = "stock distribution"
) )
func (rr RecordReader) ReadRecord(ctx context.Context) (internal.Record, error) { func (rr RecordReader) ReadRecord(ctx context.Context) (internal.Record, error) {
var splitRec *splitRecord
for { for {
raw, err := rr.reader.Read() raw, err := rr.reader.Read()
if err != nil { if err != nil {
return Record{}, fmt.Errorf("read record: %w", err) return Record{}, fmt.Errorf("read record: %w", err)
} }
var side internal.Side if strings.ToLower(raw[0]) == "action" {
switch strings.ToLower(raw[0]) {
case MarketBuy, LimitBuy:
side = internal.SideBuy
case MarketSell, LimitSell:
side = internal.SideSell
case "action", "stock split open", "stock split close":
continue continue
default:
return Record{}, fmt.Errorf("parse record type: %s", raw[0])
} }
qant, err := parseDecimal(raw[6]) qant, err := parseDecimal(raw[6])
@@ -136,9 +133,50 @@ func (rr RecordReader) ReadRecord(ctx context.Context) (internal.Record, error)
return Record{}, fmt.Errorf("parse record french transaction tax: %w", err) return Record{}, fmt.Errorf("parse record french transaction tax: %w", err)
} }
var kind internal.Kind
switch strings.ToLower(raw[0]) {
case MarketBuy, LimitBuy:
kind = internal.KindBuy
case MarketSell, LimitSell:
kind = internal.KindSell
case StockSplitOpen:
if splitRec != nil {
return nil, fmt.Errorf("split already open")
}
splitRec = &splitRecord{
Record: Record{
symbol: raw[2],
kind: internal.KindSplit,
quantity: qant,
price: price,
fees: conversionFee,
taxes: stampDutyTax.Add(frenchTxTax),
timestamp: ts,
natureGetter: figiNatureGetter(ctx, rr.figi, raw[2]),
},
}
continue
case StockSplitClose:
if splitRec == nil {
return nil, fmt.Errorf("missing split open")
}
splitRec.ratio = splitRec.Record.Quantity().Div(qant)
return splitRec, nil
case StokDistribution:
slog.Warn("Found stock distribution but can't handle it")
continue
default:
return Record{}, fmt.Errorf("parse record type: %s", raw[0])
}
return Record{ return Record{
symbol: raw[2], symbol: raw[2],
side: side, kind: kind,
quantity: qant, quantity: qant,
price: price, price: price,
fees: conversionFee, fees: conversionFee,
@@ -158,7 +196,7 @@ func figiNatureGetter(ctx context.Context, of *internal.OpenFIGI, isin string) f
} }
switch secType { switch secType {
case "Common Stock": case "Common Stock", "ADR", "REIT":
return internal.NatureG01 return internal.NatureG01
case "ETP": case "ETP":
return internal.NatureG20 return internal.NatureG20
@@ -185,3 +223,13 @@ func parseOptionalDecimal(s string) (decimal.Decimal, error) {
return parseDecimal(s) return parseDecimal(s)
} }
type splitRecord struct {
Record
ratio decimal.Decimal
}
func (sr splitRecord) Quantity() decimal.Decimal {
return sr.ratio
}

View File

@@ -30,7 +30,7 @@ func TestRecordReader_ReadRecord(t *testing.T) {
r: bytes.NewBufferString(`Market buy,2025-07-03 10:44:29,XX1234567890,ABXY,"Aspargus Broccoli",EOF987654321,2.4387014200,7.3690000000,USD,1.17995999,,"EUR",15.25,"EUR",0.25,"EUR",0.02,"EUR",,`), r: bytes.NewBufferString(`Market buy,2025-07-03 10:44:29,XX1234567890,ABXY,"Aspargus Broccoli",EOF987654321,2.4387014200,7.3690000000,USD,1.17995999,,"EUR",15.25,"EUR",0.25,"EUR",0.02,"EUR",,`),
want: Record{ want: Record{
symbol: "XX1234567890", symbol: "XX1234567890",
side: internal.SideBuy, kind: internal.KindBuy,
quantity: ShouldParseDecimal(t, "2.4387014200"), quantity: ShouldParseDecimal(t, "2.4387014200"),
price: ShouldParseDecimal(t, "7.3690000000"), price: ShouldParseDecimal(t, "7.3690000000"),
timestamp: time.Date(2025, 7, 3, 10, 44, 29, 0, time.UTC), timestamp: time.Date(2025, 7, 3, 10, 44, 29, 0, time.UTC),
@@ -44,7 +44,7 @@ func TestRecordReader_ReadRecord(t *testing.T) {
r: bytes.NewBufferString(`Market sell,2025-08-04 11:45:30,XX1234567890,ABXY,"Aspargus Broccoli",EOF987654321,2.4387014200,7.9999999999,USD,1.17995999,,"EUR",15.25,"EUR",,,0.02,"EUR",0.1,"EUR"`), r: bytes.NewBufferString(`Market sell,2025-08-04 11:45:30,XX1234567890,ABXY,"Aspargus Broccoli",EOF987654321,2.4387014200,7.9999999999,USD,1.17995999,,"EUR",15.25,"EUR",,,0.02,"EUR",0.1,"EUR"`),
want: Record{ want: Record{
symbol: "XX1234567890", symbol: "XX1234567890",
side: internal.SideSell, kind: internal.KindSell,
quantity: ShouldParseDecimal(t, "2.4387014200"), quantity: ShouldParseDecimal(t, "2.4387014200"),
price: ShouldParseDecimal(t, "7.9999999999"), price: ShouldParseDecimal(t, "7.9999999999"),
timestamp: time.Date(2025, 8, 4, 11, 45, 30, 0, time.UTC), timestamp: time.Date(2025, 8, 4, 11, 45, 30, 0, time.UTC),
@@ -123,8 +123,8 @@ func TestRecordReader_ReadRecord(t *testing.T) {
t.Fatalf("want symbol %v but got %v", tt.want.symbol, got.Symbol()) t.Fatalf("want symbol %v but got %v", tt.want.symbol, got.Symbol())
} }
if got.Side() != tt.want.side { if got.Kind() != tt.want.kind {
t.Fatalf("want side %v but got %v", tt.want.side, got.Side()) t.Fatalf("want side %v but got %v", tt.want.kind, got.Kind())
} }
if got.Price().Cmp(tt.want.price) != 0 { if got.Price().Cmp(tt.want.price) != 0 {
@@ -154,6 +154,70 @@ func TestRecordReader_ReadRecord(t *testing.T) {
} }
} }
func TestRecordReader_ReadRecord_Split(t *testing.T) {
// open row has the NEW (post-split) position: more shares at lower price
// close row has the OLD (pre-split) position: fewer shares at higher price
// ratio = openQty / closeQty = 0.5 / 0.1 = 5 (a 5:1 split)
splitOpen := `Stock split open,2025-06-03 05:34:16,XX1234567890,ABXY,"Aspargus Broccoli",EOF111111111,0.5000000000,20.0000000000,EUR,1.00000000,,,10.00,"EUR",,,,,,`
splitClose := `Stock split close,2025-06-03 05:34:16,XX1234567890,ABXY,"Aspargus Broccoli",EOF222222222,0.1000000000,100.0000000000,EUR,1.00000000,0.00,"EUR",10.00,"EUR",,,,,,`
t.Run("well-formed split pair returns split record with correct ratio", func(t *testing.T) {
rr := NewRecordReader(
bytes.NewBufferString(splitOpen+"\n"+splitClose),
NewFigiClientSecurityTypeStub(t, "Common Stock"),
)
got, err := rr.ReadRecord(t.Context())
if err != nil {
t.Fatalf("ReadRecord() failed: %v", err)
}
if got.Kind() != internal.KindSplit {
t.Errorf("want kind %v but got %v", internal.KindSplit, got.Kind())
}
if got.Symbol() != "XX1234567890" {
t.Errorf("want symbol NO0013536151 but got %v", got.Symbol())
}
wantTimestamp := time.Date(2025, 6, 3, 5, 34, 16, 0, time.UTC)
if !got.Timestamp().Equal(wantTimestamp) {
t.Errorf("want timestamp %v but got %v", wantTimestamp, got.Timestamp())
}
// ratio = openQty / closeQty = 0.1245045 / 0.0249009 ≈ 5
openQty := ShouldParseDecimal(t, "0.1245045000")
closeQty := ShouldParseDecimal(t, "0.0249009000")
wantRatio := openQty.Div(closeQty)
if !got.Quantity().Equal(wantRatio) {
t.Errorf("want ratio %v but got %v", wantRatio, got.Quantity())
}
})
t.Run("close without prior open errors", func(t *testing.T) {
rr := NewRecordReader(
bytes.NewBufferString(splitClose),
NewFigiClientSecurityTypeStub(t, "Common Stock"),
)
_, err := rr.ReadRecord(t.Context())
if err == nil {
t.Fatal("expected error but got none")
}
})
t.Run("two opens without close errors", func(t *testing.T) {
rr := NewRecordReader(
bytes.NewBufferString(splitOpen+"\n"+splitOpen),
NewFigiClientSecurityTypeStub(t, "Common Stock"),
)
_, err := rr.ReadRecord(t.Context())
if err == nil {
t.Fatal("expected error but got none")
}
})
}
func Test_figiNatureGetter(t *testing.T) { func Test_figiNatureGetter(t *testing.T) {
tests := []struct { tests := []struct {
name string // description of this test case name string // description of this test case
@@ -223,7 +287,7 @@ func NewFigiClientSecurityTypeStub(t testing.TB, securityType string) *internal.
}), }),
} }
return internal.NewOpenFIGI(c) return internal.NewOpenFIGI(c, "")
} }
func NewFigiClientErrorStub(t testing.TB, err error) *internal.OpenFIGI { func NewFigiClientErrorStub(t testing.TB, err error) *internal.OpenFIGI {
@@ -236,5 +300,5 @@ func NewFigiClientErrorStub(t testing.TB, err error) *internal.OpenFIGI {
}), }),
} }
return internal.NewOpenFIGI(c) return internal.NewOpenFIGI(c, "")
} }