Compare commits

...

2 Commits

Author SHA1 Message Date
7cc5d1cf75 handle stock split
Some checks failed
Generate check / check-changes (pull_request) Successful in 3s
Quality / check-changes (pull_request) Successful in 3s
Generate check / verify-generate (pull_request) Successful in 16s
Quality / run-tests (pull_request) Failing after 10s
2026-05-16 15:34:24 +01:00
5cdccfcdb1 rename side to kind 2026-05-16 14:21:14 +01:00
11 changed files with 384 additions and 157 deletions

View File

@@ -9,19 +9,27 @@ import (
type Filler struct { type Filler struct {
Record Record
filled decimal.Decimal filled decimal.Decimal
quantity decimal.Decimal
price decimal.Decimal
} }
func NewFiller(r Record) *Filler { func NewFiller(r Record) *Filler {
return &Filler{ return &Filler{
Record: r, Record: r,
quantity: r.Quantity(),
price: r.Price(),
} }
} }
func (f *Filler) Quantity() decimal.Decimal { return f.quantity }
func (f *Filler) Price() decimal.Decimal { return f.price }
// Fill accrues some quantity. Returns how mutch was accrued in the 1st return value and whether // Fill accrues some quantity. Returns how mutch was accrued in the 1st return value and whether
// it was filled or not on the 2nd return value. // it was filled or not on the 2nd return value.
func (f *Filler) Fill(quantity decimal.Decimal) (decimal.Decimal, bool) { func (f *Filler) Fill(quantity decimal.Decimal) (decimal.Decimal, bool) {
unfilled := f.Record.Quantity().Sub(f.filled) unfilled := f.quantity.Sub(f.filled)
delta := decimal.Min(unfilled, quantity) delta := decimal.Min(unfilled, quantity)
f.filled = f.filled.Add(delta) f.filled = f.filled.Add(delta)
return delta, f.IsFilled() return delta, f.IsFilled()
@@ -29,7 +37,15 @@ func (f *Filler) Fill(quantity decimal.Decimal) (decimal.Decimal, bool) {
// IsFilled returns true if the fill is equal to the record quantity. // IsFilled returns true if the fill is equal to the record quantity.
func (f *Filler) IsFilled() bool { func (f *Filler) IsFilled() bool {
return f.filled.Equal(f.Quantity()) return f.filled.Equal(f.quantity)
}
// ApplySplit adjusts the lot for a stock split by the given ratio (newQty/oldQty).
// The total cost basis is preserved: quantity scales up, price scales down proportionally.
func (f *Filler) ApplySplit(ratio decimal.Decimal) {
f.quantity = f.quantity.Mul(ratio)
f.filled = f.filled.Mul(ratio)
f.price = f.price.Div(ratio)
} }
type FillerQueue struct { type FillerQueue struct {
@@ -86,6 +102,16 @@ func (fq *FillerQueue) frontElement() *list.Element {
return fq.l.Front() return fq.l.Front()
} }
// AdjustForSplit applies a stock split ratio to all lots in the queue.
func (fq *FillerQueue) AdjustForSplit(ratio decimal.Decimal) {
if fq == nil || fq.l == nil {
return
}
for e := fq.l.Front(); e != nil; e = e.Next() {
e.Value.(*Filler).ApplySplit(ratio)
}
}
// Len returns how many elements are currently on the queue // Len returns how many elements are currently on the queue
func (fq *FillerQueue) Len() int { func (fq *FillerQueue) Len() int {
if fq == nil || fq.l == nil { if fq == nil || fq.l == nil {

View File

@@ -114,7 +114,7 @@ func TestFillerQueueNilReceiver(t *testing.T) {
t.Fatalf(`want panic message %q but got "%v"`, expMsg, r) t.Fatalf(`want panic message %q but got "%v"`, expMsg, r)
} }
}() }()
rq.Push(NewFiller(nil)) rq.Push(NewFiller(&testRecord{}))
} }
type testRecord struct { type testRecord struct {
@@ -122,11 +122,11 @@ type testRecord struct {
id int id int
quantity decimal.Decimal quantity decimal.Decimal
price decimal.Decimal
} }
func (tr testRecord) Quantity() decimal.Decimal { func (tr testRecord) Quantity() decimal.Decimal { return tr.quantity }
return tr.quantity func (tr testRecord) Price() decimal.Decimal { return tr.price }
}
func TestFiller_Fill(t *testing.T) { func TestFiller_Fill(t *testing.T) {
tests := []struct { tests := []struct {
@@ -185,3 +185,89 @@ func TestFiller_Fill(t *testing.T) {
}) })
} }
} }
func TestFiller_ApplySplit(t *testing.T) {
tests := []struct {
name string
qty float64
price float64
prefilled float64
ratio float64
wantQty float64
wantPrice float64
wantFilled float64
wantCostBasis float64
}{
{
name: "5:1 split on unfilled lot preserves cost basis",
qty: 10, price: 100, prefilled: 0, ratio: 5,
wantQty: 50, wantPrice: 20, wantFilled: 0, wantCostBasis: 1000,
},
{
name: "5:1 split on partially filled lot",
qty: 10, price: 100, prefilled: 4, ratio: 5,
wantQty: 50, wantPrice: 20, wantFilled: 20, wantCostBasis: 1000,
},
{
name: "1:2 reverse split on unfilled lot preserves cost basis",
qty: 10, price: 100, prefilled: 0, ratio: 0.5,
wantQty: 5, wantPrice: 200, wantFilled: 0, wantCostBasis: 1000,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
f := NewFiller(&testRecord{
quantity: decimal.NewFromFloat(tt.qty),
price: decimal.NewFromFloat(tt.price),
})
if tt.prefilled > 0 {
f.Fill(decimal.NewFromFloat(tt.prefilled))
}
f.ApplySplit(decimal.NewFromFloat(tt.ratio))
if !f.Quantity().Equal(decimal.NewFromFloat(tt.wantQty)) {
t.Errorf("want quantity %v but got %v", tt.wantQty, f.Quantity())
}
if !f.Price().Equal(decimal.NewFromFloat(tt.wantPrice)) {
t.Errorf("want price %v but got %v", tt.wantPrice, f.Price())
}
if !f.filled.Equal(decimal.NewFromFloat(tt.wantFilled)) {
t.Errorf("want filled %v but got %v", tt.wantFilled, f.filled)
}
costBasis := f.Quantity().Mul(f.Price())
if !costBasis.Equal(decimal.NewFromFloat(tt.wantCostBasis)) {
t.Errorf("want cost basis %v but got %v", tt.wantCostBasis, costBasis)
}
})
}
}
func TestFillerQueue_AdjustForSplit(t *testing.T) {
var fq FillerQueue
fq.Push(NewFiller(&testRecord{quantity: decimal.NewFromFloat(10), price: decimal.NewFromFloat(100)}))
fq.Push(NewFiller(&testRecord{quantity: decimal.NewFromFloat(5), price: decimal.NewFromFloat(200)}))
fq.AdjustForSplit(decimal.NewFromFloat(5))
lot1, _ := fq.Pop()
if !lot1.Quantity().Equal(decimal.NewFromFloat(50)) {
t.Errorf("lot1: want quantity 50 but got %v", lot1.Quantity())
}
if !lot1.Price().Equal(decimal.NewFromFloat(20)) {
t.Errorf("lot1: want price 20 but got %v", lot1.Price())
}
lot2, _ := fq.Pop()
if !lot2.Quantity().Equal(decimal.NewFromFloat(25)) {
t.Errorf("lot2: want quantity 25 but got %v", lot2.Quantity())
}
if !lot2.Price().Equal(decimal.NewFromFloat(40)) {
t.Errorf("lot2: want price 40 but got %v", lot2.Price())
}
}
func TestFillerQueue_AdjustForSplit_NilReceiver(t *testing.T) {
var fq *FillerQueue
fq.AdjustForSplit(decimal.NewFromFloat(5)) // must not panic
}

30
internal/kind.go Normal file
View File

@@ -0,0 +1,30 @@
package internal
type Kind uint
const (
KindUnknown Kind = iota
KindBuy
KindSell
KindSplit
)
// String returns a human readable value
func (d Kind) String() string {
switch d {
case KindBuy:
return "buy"
case KindSell:
return "sell"
case KindSplit:
return "split"
default:
return "unknown"
}
}
// Is returns true when k equals o
func (k Kind) Is(o any) bool {
other, ok := o.(Kind)
return ok && k == other
}

View File

@@ -5,12 +5,12 @@ import "testing"
func TestSide_String(t *testing.T) { func TestSide_String(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
side Side side Kind
want string want string
}{ }{
{"buy", SideBuy, "buy"}, {"buy", KindBuy, "buy"},
{"sell", SideSell, "sell"}, {"sell", KindSell, "sell"},
{"unknown", SideUnknown, "unknown"}, {"unknown", KindUnknown, "unknown"},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
@@ -24,16 +24,16 @@ func TestSide_String(t *testing.T) {
func TestSide_IsBuy(t *testing.T) { func TestSide_IsBuy(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
side Side side Kind
want bool want bool
}{ }{
{"buy", SideBuy, true}, {"buy", KindBuy, true},
{"sell", SideSell, false}, {"sell", KindSell, false},
{"unknown", SideUnknown, false}, {"unknown", KindUnknown, false},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
if got := tt.side.IsBuy(); got != tt.want { if got := tt.side.Is(KindBuy); got != tt.want {
t.Errorf("want Side.IsBuy() to be %v but got %v", tt.want, got) t.Errorf("want Side.IsBuy() to be %v but got %v", tt.want, got)
} }
}) })
@@ -43,16 +43,16 @@ func TestSide_IsBuy(t *testing.T) {
func TestSide_IsSell(t *testing.T) { func TestSide_IsSell(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
side Side side Kind
want bool want bool
}{ }{
{"buy", SideBuy, false}, {"buy", KindBuy, false},
{"sell", SideSell, true}, {"sell", KindSell, true},
{"unknown", SideUnknown, false}, {"unknown", KindUnknown, false},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
if got := tt.side.IsSell(); got != tt.want { if got := tt.side.Is(KindSell); got != tt.want {
t.Errorf("want Side.IsSell() to be %v but got %v", tt.want, got) t.Errorf("want Side.IsSell() to be %v but got %v", tt.want, got)
} }
}) })

View File

@@ -220,6 +220,44 @@ func (c *MockRecordFeesCall) DoAndReturn(f func() decimal.Decimal) *MockRecordFe
return c return c
} }
// Kind mocks base method.
func (m *MockRecord) Kind() internal.Kind {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Kind")
ret0, _ := ret[0].(internal.Kind)
return ret0
}
// Kind indicates an expected call of Kind.
func (mr *MockRecordMockRecorder) Kind() *MockRecordKindCall {
mr.mock.ctrl.T.Helper()
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Kind", reflect.TypeOf((*MockRecord)(nil).Kind))
return &MockRecordKindCall{Call: call}
}
// MockRecordKindCall wrap *gomock.Call
type MockRecordKindCall struct {
*gomock.Call
}
// Return rewrite *gomock.Call.Return
func (c *MockRecordKindCall) Return(arg0 internal.Kind) *MockRecordKindCall {
c.Call = c.Call.Return(arg0)
return c
}
// Do rewrite *gomock.Call.Do
func (c *MockRecordKindCall) Do(f func() internal.Kind) *MockRecordKindCall {
c.Call = c.Call.Do(f)
return c
}
// DoAndReturn rewrite *gomock.Call.DoAndReturn
func (c *MockRecordKindCall) DoAndReturn(f func() internal.Kind) *MockRecordKindCall {
c.Call = c.Call.DoAndReturn(f)
return c
}
// Nature mocks base method. // Nature mocks base method.
func (m *MockRecord) Nature() internal.Nature { func (m *MockRecord) Nature() internal.Nature {
m.ctrl.T.Helper() m.ctrl.T.Helper()
@@ -334,44 +372,6 @@ func (c *MockRecordQuantityCall) DoAndReturn(f func() decimal.Decimal) *MockReco
return c return c
} }
// Side mocks base method.
func (m *MockRecord) Side() internal.Side {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Side")
ret0, _ := ret[0].(internal.Side)
return ret0
}
// Side indicates an expected call of Side.
func (mr *MockRecordMockRecorder) Side() *MockRecordSideCall {
mr.mock.ctrl.T.Helper()
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Side", reflect.TypeOf((*MockRecord)(nil).Side))
return &MockRecordSideCall{Call: call}
}
// MockRecordSideCall wrap *gomock.Call
type MockRecordSideCall struct {
*gomock.Call
}
// Return rewrite *gomock.Call.Return
func (c *MockRecordSideCall) Return(arg0 internal.Side) *MockRecordSideCall {
c.Call = c.Call.Return(arg0)
return c
}
// Do rewrite *gomock.Call.Do
func (c *MockRecordSideCall) Do(f func() internal.Side) *MockRecordSideCall {
c.Call = c.Call.Do(f)
return c
}
// DoAndReturn rewrite *gomock.Call.DoAndReturn
func (c *MockRecordSideCall) DoAndReturn(f func() internal.Side) *MockRecordSideCall {
c.Call = c.Call.DoAndReturn(f)
return c
}
// Symbol mocks base method. // Symbol mocks base method.
func (m *MockRecord) Symbol() string { func (m *MockRecord) Symbol() string {
m.ctrl.T.Helper() m.ctrl.T.Helper()

View File

@@ -16,7 +16,7 @@ type Record interface {
Nature() Nature Nature() Nature
BrokerCountry() int64 BrokerCountry() int64
AssetCountry() int64 AssetCountry() int64
Side() Side Kind() Kind
Price() decimal.Decimal Price() decimal.Decimal
Quantity() decimal.Decimal Quantity() decimal.Decimal
Timestamp() time.Time Timestamp() time.Time
@@ -83,9 +83,9 @@ func BuildReport(ctx context.Context, reader RecordReader, writer ReportWriter,
return err return err
} }
if rec.Side().IsBuy() { if rec.Kind().Is(KindBuy) {
buysCount++ buysCount++
} else { } else if rec.Kind().Is(KindSell) {
sellsCount++ sellsCount++
} }
@@ -108,23 +108,23 @@ func BuildReport(ctx context.Context, reader RecordReader, writer ReportWriter,
// processRecord either adds buys to the queue or consumes buys from the queue when processing a // processRecord either adds buys to the queue or consumes buys from the queue when processing a
// sell record. // sell record.
// Selectors are only applied for sells for performance reasons. It's much cheaper to just accumulate // Selectors are only applied on sells for performance reasons. It's much cheaper to just accumulate
// buys and only actually inspect a record once a sell happens // buys and only actually inspect a record once a sell happens due to potential network requests to
func processRecord(ctx context.Context, q *FillerQueue, rec Record, sel Selector, writer ReportWriter) error { func processRecord(ctx context.Context, q *FillerQueue, rec Record, sel Selector, writer ReportWriter) error {
slog.Debug("Report: processing record", slog.Debug("Report: processing record",
slog.String("symbol", rec.Symbol()), slog.String("symbol", rec.Symbol()),
slog.String("side", rec.Side().String()), slog.String("side", rec.Kind().String()),
) )
switch rec.Side() { switch rec.Kind() {
case SideBuy: case KindBuy:
q.Push(NewFiller(rec)) q.Push(NewFiller(rec))
case SideSell: case KindSell:
if !sel(rec) { if !sel(rec) {
slog.Debug("Report: skipping record", slog.Debug("Report: skipping record",
slog.String("symbol", rec.Symbol()), slog.String("symbol", rec.Symbol()),
slog.String("side", rec.Side().String()), slog.String("side", rec.Kind().String()),
) )
return nil return nil
} }
@@ -168,8 +168,11 @@ func processRecord(ctx context.Context, q *FillerQueue, rec Record, sel Selector
} }
} }
case KindSplit:
q.AdjustForSplit(rec.Quantity())
default: default:
return fmt.Errorf("unknown side: %v", rec.Side()) return fmt.Errorf("unknown side: %v", rec.Kind())
} }
return nil return nil

View File

@@ -20,8 +20,8 @@ func TestBuildReport(t *testing.T) {
reader := mocks.NewMockRecordReader(ctrl) reader := mocks.NewMockRecordReader(ctrl)
records := []internal.Record{ records := []internal.Record{
mockRecord(ctrl, 20.0, 10.0, internal.SideBuy, now), mockRecord(ctrl, 20.0, 10.0, internal.KindBuy, now),
mockRecord(ctrl, 25.0, 10.0, internal.SideSell, now.Add(1)), mockRecord(ctrl, 25.0, 10.0, internal.KindSell, now.Add(1)),
} }
reader.EXPECT().ReadRecord(gomock.Any()).DoAndReturn(func(ctx context.Context) (internal.Record, error) { reader.EXPECT().ReadRecord(gomock.Any()).DoAndReturn(func(ctx context.Context) (internal.Record, error) {
if len(records) > 0 { if len(records) > 0 {
@@ -49,14 +49,14 @@ func TestBuildReport(t *testing.T) {
} }
} }
func mockRecord(ctrl *gomock.Controller, price, quantity float64, side internal.Side, ts time.Time) *mocks.MockRecord { func mockRecord(ctrl *gomock.Controller, price, quantity float64, kind internal.Kind, ts time.Time) *mocks.MockRecord {
rec := mocks.NewMockRecord(ctrl) rec := mocks.NewMockRecord(ctrl)
rec.EXPECT().Symbol().Return("TEST").AnyTimes() rec.EXPECT().Symbol().Return("TEST").AnyTimes()
rec.EXPECT().BrokerCountry().Return(int64(countries.PT)).AnyTimes() rec.EXPECT().BrokerCountry().Return(int64(countries.PT)).AnyTimes()
rec.EXPECT().AssetCountry().Return(int64(countries.USA)).AnyTimes() rec.EXPECT().AssetCountry().Return(int64(countries.USA)).AnyTimes()
rec.EXPECT().Price().Return(decimal.NewFromFloat(price)).AnyTimes() rec.EXPECT().Price().Return(decimal.NewFromFloat(price)).AnyTimes()
rec.EXPECT().Quantity().Return(decimal.NewFromFloat(quantity)).AnyTimes() rec.EXPECT().Quantity().Return(decimal.NewFromFloat(quantity)).AnyTimes()
rec.EXPECT().Side().Return(side).AnyTimes() rec.EXPECT().Kind().Return(kind).AnyTimes()
rec.EXPECT().Timestamp().Return(ts).AnyTimes() rec.EXPECT().Timestamp().Return(ts).AnyTimes()
rec.EXPECT().Fees().Return(decimal.Decimal{}).AnyTimes() rec.EXPECT().Fees().Return(decimal.Decimal{}).AnyTimes()
rec.EXPECT().Taxes().Return(decimal.Decimal{}).AnyTimes() rec.EXPECT().Taxes().Return(decimal.Decimal{}).AnyTimes()

View File

@@ -9,28 +9,28 @@ import (
) )
type testRecord struct { type testRecord struct {
symbol string symbol string
nature internal.Nature nature internal.Nature
brokerCountry int64 brokerCountry int64
assetCountry int64 assetCountry int64
side internal.Side side internal.Kind
price decimal.Decimal price decimal.Decimal
quantity decimal.Decimal quantity decimal.Decimal
timestamp time.Time timestamp time.Time
fees decimal.Decimal fees decimal.Decimal
taxes decimal.Decimal taxes decimal.Decimal
} }
func (m testRecord) Symbol() string { return m.symbol } func (m testRecord) Symbol() string { return m.symbol }
func (m testRecord) Nature() internal.Nature { return m.nature } func (m testRecord) Nature() internal.Nature { return m.nature }
func (m testRecord) BrokerCountry() int64 { return m.brokerCountry } func (m testRecord) BrokerCountry() int64 { return m.brokerCountry }
func (m testRecord) AssetCountry() int64 { return m.assetCountry } func (m testRecord) AssetCountry() int64 { return m.assetCountry }
func (m testRecord) Side() internal.Side { return m.side } func (m testRecord) Kind() internal.Kind { return m.side }
func (m testRecord) Price() decimal.Decimal { return m.price } func (m testRecord) Price() decimal.Decimal { return m.price }
func (m testRecord) Quantity() decimal.Decimal { return m.quantity } func (m testRecord) Quantity() decimal.Decimal { return m.quantity }
func (m testRecord) Timestamp() time.Time { return m.timestamp } func (m testRecord) Timestamp() time.Time { return m.timestamp }
func (m testRecord) Fees() decimal.Decimal { return m.fees } func (m testRecord) Fees() decimal.Decimal { return m.fees }
func (m testRecord) Taxes() decimal.Decimal { return m.taxes } func (m testRecord) Taxes() decimal.Decimal { return m.taxes }
func TestAny(t *testing.T) { func TestAny(t *testing.T) {
selector := internal.Any() selector := internal.Any()
@@ -43,9 +43,9 @@ func TestAny(t *testing.T) {
{ {
name: "returns true for any record", name: "returns true for any record",
record: testRecord{ record: testRecord{
symbol: "AAPL", symbol: "AAPL",
nature: internal.NatureG01, nature: internal.NatureG01,
assetCountry: 1, assetCountry: 1,
}, },
want: true, want: true,
}, },
@@ -59,7 +59,7 @@ func TestAny(t *testing.T) {
want: true, want: true,
}, },
{ {
name: "returns true for empty record", name: "returns true for empty record",
record: testRecord{}, record: testRecord{},
want: true, want: true,
}, },
@@ -77,10 +77,10 @@ func TestAny(t *testing.T) {
func TestOnlyNature(t *testing.T) { func TestOnlyNature(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
nature internal.Nature nature internal.Nature
record internal.Record record internal.Record
want bool want bool
}{ }{
{ {
name: "matches G01 nature", name: "matches G01 nature",
@@ -133,10 +133,10 @@ func TestOnlyNature(t *testing.T) {
func TestOnlyAssetCountry(t *testing.T) { func TestOnlyAssetCountry(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
country int64 country int64
record internal.Record record internal.Record
want bool want bool
}{ }{
{ {
name: "matches asset country", name: "matches asset country",

View File

@@ -1,30 +0,0 @@
package internal
type Side uint
const (
SideUnknown Side = iota
SideBuy
SideSell
)
func (d Side) String() string {
switch d {
case SideBuy:
return "buy"
case SideSell:
return "sell"
default:
return "unknown"
}
}
// IsBuy returns true if the s == SideBuy
func (d Side) IsBuy() bool {
return d == SideBuy
}
// IsSell returns true if the s == SideSell
func (d Side) IsSell() bool {
return d == SideSell
}

View File

@@ -18,7 +18,7 @@ import (
type Record struct { type Record struct {
symbol string symbol string
timestamp time.Time timestamp time.Time
side internal.Side kind internal.Kind
quantity decimal.Decimal quantity decimal.Decimal
price decimal.Decimal price decimal.Decimal
fees decimal.Decimal fees decimal.Decimal
@@ -44,8 +44,8 @@ func (r Record) AssetCountry() int64 {
return int64(countries.ByName(r.Symbol()[:2]).Info().Code) return int64(countries.ByName(r.Symbol()[:2]).Info().Code)
} }
func (r Record) Side() internal.Side { func (r Record) Kind() internal.Kind {
return r.side return r.kind
} }
func (r Record) Quantity() decimal.Decimal { func (r Record) Quantity() decimal.Decimal {
@@ -81,29 +81,26 @@ func NewRecordReader(r io.Reader, f *internal.OpenFIGI) *RecordReader {
} }
const ( const (
MarketBuy = "market buy" MarketBuy = "market buy"
MarketSell = "market sell" MarketSell = "market sell"
LimitBuy = "limit buy" LimitBuy = "limit buy"
LimitSell = "limit sell" LimitSell = "limit sell"
StockSplitOpen = "stock split open"
StockSplitClose = "stock split close"
StokDistribution = "stock distribution"
) )
func (rr RecordReader) ReadRecord(ctx context.Context) (internal.Record, error) { func (rr RecordReader) ReadRecord(ctx context.Context) (internal.Record, error) {
var splitRec *splitRecord
for { for {
raw, err := rr.reader.Read() raw, err := rr.reader.Read()
if err != nil { if err != nil {
return Record{}, fmt.Errorf("read record: %w", err) return Record{}, fmt.Errorf("read record: %w", err)
} }
var side internal.Side if strings.ToLower(raw[0]) == "action" {
switch strings.ToLower(raw[0]) {
case MarketBuy, LimitBuy:
side = internal.SideBuy
case MarketSell, LimitSell:
side = internal.SideSell
case "action", "stock split open", "stock split close":
continue continue
default:
return Record{}, fmt.Errorf("parse record type: %s", raw[0])
} }
qant, err := parseDecimal(raw[6]) qant, err := parseDecimal(raw[6])
@@ -136,9 +133,50 @@ func (rr RecordReader) ReadRecord(ctx context.Context) (internal.Record, error)
return Record{}, fmt.Errorf("parse record french transaction tax: %w", err) return Record{}, fmt.Errorf("parse record french transaction tax: %w", err)
} }
var kind internal.Kind
switch strings.ToLower(raw[0]) {
case MarketBuy, LimitBuy:
kind = internal.KindBuy
case MarketSell, LimitSell:
kind = internal.KindSell
case StockSplitOpen:
if splitRec != nil {
return nil, fmt.Errorf("split already open")
}
splitRec = &splitRecord{
Record: Record{
symbol: raw[2],
kind: internal.KindSplit,
quantity: qant,
price: price,
fees: conversionFee,
taxes: stampDutyTax.Add(frenchTxTax),
timestamp: ts,
natureGetter: figiNatureGetter(ctx, rr.figi, raw[2]),
},
}
continue
case StockSplitClose:
if splitRec == nil {
return nil, fmt.Errorf("missing split open")
}
splitRec.ratio = splitRec.Record.Quantity().Div(qant)
return splitRec, nil
case StokDistribution:
slog.Warn("Found stock distribution but can't handle it")
continue
default:
return Record{}, fmt.Errorf("parse record type: %s", raw[0])
}
return Record{ return Record{
symbol: raw[2], symbol: raw[2],
side: side, kind: kind,
quantity: qant, quantity: qant,
price: price, price: price,
fees: conversionFee, fees: conversionFee,
@@ -185,3 +223,13 @@ func parseOptionalDecimal(s string) (decimal.Decimal, error) {
return parseDecimal(s) return parseDecimal(s)
} }
type splitRecord struct {
Record
ratio decimal.Decimal
}
func (sr splitRecord) Quantity() decimal.Decimal {
return sr.ratio
}

View File

@@ -30,7 +30,7 @@ func TestRecordReader_ReadRecord(t *testing.T) {
r: bytes.NewBufferString(`Market buy,2025-07-03 10:44:29,XX1234567890,ABXY,"Aspargus Broccoli",EOF987654321,2.4387014200,7.3690000000,USD,1.17995999,,"EUR",15.25,"EUR",0.25,"EUR",0.02,"EUR",,`), r: bytes.NewBufferString(`Market buy,2025-07-03 10:44:29,XX1234567890,ABXY,"Aspargus Broccoli",EOF987654321,2.4387014200,7.3690000000,USD,1.17995999,,"EUR",15.25,"EUR",0.25,"EUR",0.02,"EUR",,`),
want: Record{ want: Record{
symbol: "XX1234567890", symbol: "XX1234567890",
side: internal.SideBuy, kind: internal.KindBuy,
quantity: ShouldParseDecimal(t, "2.4387014200"), quantity: ShouldParseDecimal(t, "2.4387014200"),
price: ShouldParseDecimal(t, "7.3690000000"), price: ShouldParseDecimal(t, "7.3690000000"),
timestamp: time.Date(2025, 7, 3, 10, 44, 29, 0, time.UTC), timestamp: time.Date(2025, 7, 3, 10, 44, 29, 0, time.UTC),
@@ -44,7 +44,7 @@ func TestRecordReader_ReadRecord(t *testing.T) {
r: bytes.NewBufferString(`Market sell,2025-08-04 11:45:30,XX1234567890,ABXY,"Aspargus Broccoli",EOF987654321,2.4387014200,7.9999999999,USD,1.17995999,,"EUR",15.25,"EUR",,,0.02,"EUR",0.1,"EUR"`), r: bytes.NewBufferString(`Market sell,2025-08-04 11:45:30,XX1234567890,ABXY,"Aspargus Broccoli",EOF987654321,2.4387014200,7.9999999999,USD,1.17995999,,"EUR",15.25,"EUR",,,0.02,"EUR",0.1,"EUR"`),
want: Record{ want: Record{
symbol: "XX1234567890", symbol: "XX1234567890",
side: internal.SideSell, kind: internal.KindSell,
quantity: ShouldParseDecimal(t, "2.4387014200"), quantity: ShouldParseDecimal(t, "2.4387014200"),
price: ShouldParseDecimal(t, "7.9999999999"), price: ShouldParseDecimal(t, "7.9999999999"),
timestamp: time.Date(2025, 8, 4, 11, 45, 30, 0, time.UTC), timestamp: time.Date(2025, 8, 4, 11, 45, 30, 0, time.UTC),
@@ -123,8 +123,8 @@ func TestRecordReader_ReadRecord(t *testing.T) {
t.Fatalf("want symbol %v but got %v", tt.want.symbol, got.Symbol()) t.Fatalf("want symbol %v but got %v", tt.want.symbol, got.Symbol())
} }
if got.Side() != tt.want.side { if got.Kind() != tt.want.kind {
t.Fatalf("want side %v but got %v", tt.want.side, got.Side()) t.Fatalf("want side %v but got %v", tt.want.kind, got.Kind())
} }
if got.Price().Cmp(tt.want.price) != 0 { if got.Price().Cmp(tt.want.price) != 0 {
@@ -154,6 +154,70 @@ func TestRecordReader_ReadRecord(t *testing.T) {
} }
} }
func TestRecordReader_ReadRecord_Split(t *testing.T) {
// open row has the NEW (post-split) position: more shares at lower price
// close row has the OLD (pre-split) position: fewer shares at higher price
// ratio = openQty / closeQty = 0.5 / 0.1 = 5 (a 5:1 split)
splitOpen := `Stock split open,2025-06-03 05:34:16,XX1234567890,ABXY,"Aspargus Broccoli",EOF111111111,0.5000000000,20.0000000000,EUR,1.00000000,,,10.00,"EUR",,,,,,`
splitClose := `Stock split close,2025-06-03 05:34:16,XX1234567890,ABXY,"Aspargus Broccoli",EOF222222222,0.1000000000,100.0000000000,EUR,1.00000000,0.00,"EUR",10.00,"EUR",,,,,,`
t.Run("well-formed split pair returns split record with correct ratio", func(t *testing.T) {
rr := NewRecordReader(
bytes.NewBufferString(splitOpen+"\n"+splitClose),
NewFigiClientSecurityTypeStub(t, "Common Stock"),
)
got, err := rr.ReadRecord(t.Context())
if err != nil {
t.Fatalf("ReadRecord() failed: %v", err)
}
if got.Kind() != internal.KindSplit {
t.Errorf("want kind %v but got %v", internal.KindSplit, got.Kind())
}
if got.Symbol() != "NO0013536151" {
t.Errorf("want symbol NO0013536151 but got %v", got.Symbol())
}
wantTimestamp := time.Date(2025, 6, 3, 5, 34, 16, 0, time.UTC)
if !got.Timestamp().Equal(wantTimestamp) {
t.Errorf("want timestamp %v but got %v", wantTimestamp, got.Timestamp())
}
// ratio = openQty / closeQty = 0.1245045 / 0.0249009 ≈ 5
openQty := ShouldParseDecimal(t, "0.1245045000")
closeQty := ShouldParseDecimal(t, "0.0249009000")
wantRatio := openQty.Div(closeQty)
if !got.Quantity().Equal(wantRatio) {
t.Errorf("want ratio %v but got %v", wantRatio, got.Quantity())
}
})
t.Run("close without prior open errors", func(t *testing.T) {
rr := NewRecordReader(
bytes.NewBufferString(splitClose),
NewFigiClientSecurityTypeStub(t, "Common Stock"),
)
_, err := rr.ReadRecord(t.Context())
if err == nil {
t.Fatal("expected error but got none")
}
})
t.Run("two opens without close errors", func(t *testing.T) {
rr := NewRecordReader(
bytes.NewBufferString(splitOpen+"\n"+splitOpen),
NewFigiClientSecurityTypeStub(t, "Common Stock"),
)
_, err := rr.ReadRecord(t.Context())
if err == nil {
t.Fatal("expected error but got none")
}
})
}
func Test_figiNatureGetter(t *testing.T) { func Test_figiNatureGetter(t *testing.T) {
tests := []struct { tests := []struct {
name string // description of this test case name string // description of this test case