handle stock split
Some checks failed
Generate check / check-changes (pull_request) Successful in 3s
Quality / check-changes (pull_request) Successful in 3s
Generate check / verify-generate (pull_request) Successful in 16s
Quality / run-tests (pull_request) Failing after 10s

This commit is contained in:
2026-05-16 15:34:24 +01:00
parent 5cdccfcdb1
commit 7cc5d1cf75
6 changed files with 258 additions and 35 deletions

View File

@@ -9,19 +9,27 @@ import (
type Filler struct { type Filler struct {
Record Record
filled decimal.Decimal filled decimal.Decimal
quantity decimal.Decimal
price decimal.Decimal
} }
func NewFiller(r Record) *Filler { func NewFiller(r Record) *Filler {
return &Filler{ return &Filler{
Record: r, Record: r,
quantity: r.Quantity(),
price: r.Price(),
} }
} }
func (f *Filler) Quantity() decimal.Decimal { return f.quantity }
func (f *Filler) Price() decimal.Decimal { return f.price }
// Fill accrues some quantity. Returns how mutch was accrued in the 1st return value and whether // Fill accrues some quantity. Returns how mutch was accrued in the 1st return value and whether
// it was filled or not on the 2nd return value. // it was filled or not on the 2nd return value.
func (f *Filler) Fill(quantity decimal.Decimal) (decimal.Decimal, bool) { func (f *Filler) Fill(quantity decimal.Decimal) (decimal.Decimal, bool) {
unfilled := f.Record.Quantity().Sub(f.filled) unfilled := f.quantity.Sub(f.filled)
delta := decimal.Min(unfilled, quantity) delta := decimal.Min(unfilled, quantity)
f.filled = f.filled.Add(delta) f.filled = f.filled.Add(delta)
return delta, f.IsFilled() return delta, f.IsFilled()
@@ -29,7 +37,15 @@ func (f *Filler) Fill(quantity decimal.Decimal) (decimal.Decimal, bool) {
// IsFilled returns true if the fill is equal to the record quantity. // IsFilled returns true if the fill is equal to the record quantity.
func (f *Filler) IsFilled() bool { func (f *Filler) IsFilled() bool {
return f.filled.Equal(f.Quantity()) return f.filled.Equal(f.quantity)
}
// ApplySplit adjusts the lot for a stock split by the given ratio (newQty/oldQty).
// The total cost basis is preserved: quantity scales up, price scales down proportionally.
func (f *Filler) ApplySplit(ratio decimal.Decimal) {
f.quantity = f.quantity.Mul(ratio)
f.filled = f.filled.Mul(ratio)
f.price = f.price.Div(ratio)
} }
type FillerQueue struct { type FillerQueue struct {
@@ -86,6 +102,16 @@ func (fq *FillerQueue) frontElement() *list.Element {
return fq.l.Front() return fq.l.Front()
} }
// AdjustForSplit applies a stock split ratio to all lots in the queue.
func (fq *FillerQueue) AdjustForSplit(ratio decimal.Decimal) {
if fq == nil || fq.l == nil {
return
}
for e := fq.l.Front(); e != nil; e = e.Next() {
e.Value.(*Filler).ApplySplit(ratio)
}
}
// Len returns how many elements are currently on the queue // Len returns how many elements are currently on the queue
func (fq *FillerQueue) Len() int { func (fq *FillerQueue) Len() int {
if fq == nil || fq.l == nil { if fq == nil || fq.l == nil {

View File

@@ -114,7 +114,7 @@ func TestFillerQueueNilReceiver(t *testing.T) {
t.Fatalf(`want panic message %q but got "%v"`, expMsg, r) t.Fatalf(`want panic message %q but got "%v"`, expMsg, r)
} }
}() }()
rq.Push(NewFiller(nil)) rq.Push(NewFiller(&testRecord{}))
} }
type testRecord struct { type testRecord struct {
@@ -122,11 +122,11 @@ type testRecord struct {
id int id int
quantity decimal.Decimal quantity decimal.Decimal
price decimal.Decimal
} }
func (tr testRecord) Quantity() decimal.Decimal { func (tr testRecord) Quantity() decimal.Decimal { return tr.quantity }
return tr.quantity func (tr testRecord) Price() decimal.Decimal { return tr.price }
}
func TestFiller_Fill(t *testing.T) { func TestFiller_Fill(t *testing.T) {
tests := []struct { tests := []struct {
@@ -185,3 +185,89 @@ func TestFiller_Fill(t *testing.T) {
}) })
} }
} }
func TestFiller_ApplySplit(t *testing.T) {
tests := []struct {
name string
qty float64
price float64
prefilled float64
ratio float64
wantQty float64
wantPrice float64
wantFilled float64
wantCostBasis float64
}{
{
name: "5:1 split on unfilled lot preserves cost basis",
qty: 10, price: 100, prefilled: 0, ratio: 5,
wantQty: 50, wantPrice: 20, wantFilled: 0, wantCostBasis: 1000,
},
{
name: "5:1 split on partially filled lot",
qty: 10, price: 100, prefilled: 4, ratio: 5,
wantQty: 50, wantPrice: 20, wantFilled: 20, wantCostBasis: 1000,
},
{
name: "1:2 reverse split on unfilled lot preserves cost basis",
qty: 10, price: 100, prefilled: 0, ratio: 0.5,
wantQty: 5, wantPrice: 200, wantFilled: 0, wantCostBasis: 1000,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
f := NewFiller(&testRecord{
quantity: decimal.NewFromFloat(tt.qty),
price: decimal.NewFromFloat(tt.price),
})
if tt.prefilled > 0 {
f.Fill(decimal.NewFromFloat(tt.prefilled))
}
f.ApplySplit(decimal.NewFromFloat(tt.ratio))
if !f.Quantity().Equal(decimal.NewFromFloat(tt.wantQty)) {
t.Errorf("want quantity %v but got %v", tt.wantQty, f.Quantity())
}
if !f.Price().Equal(decimal.NewFromFloat(tt.wantPrice)) {
t.Errorf("want price %v but got %v", tt.wantPrice, f.Price())
}
if !f.filled.Equal(decimal.NewFromFloat(tt.wantFilled)) {
t.Errorf("want filled %v but got %v", tt.wantFilled, f.filled)
}
costBasis := f.Quantity().Mul(f.Price())
if !costBasis.Equal(decimal.NewFromFloat(tt.wantCostBasis)) {
t.Errorf("want cost basis %v but got %v", tt.wantCostBasis, costBasis)
}
})
}
}
func TestFillerQueue_AdjustForSplit(t *testing.T) {
var fq FillerQueue
fq.Push(NewFiller(&testRecord{quantity: decimal.NewFromFloat(10), price: decimal.NewFromFloat(100)}))
fq.Push(NewFiller(&testRecord{quantity: decimal.NewFromFloat(5), price: decimal.NewFromFloat(200)}))
fq.AdjustForSplit(decimal.NewFromFloat(5))
lot1, _ := fq.Pop()
if !lot1.Quantity().Equal(decimal.NewFromFloat(50)) {
t.Errorf("lot1: want quantity 50 but got %v", lot1.Quantity())
}
if !lot1.Price().Equal(decimal.NewFromFloat(20)) {
t.Errorf("lot1: want price 20 but got %v", lot1.Price())
}
lot2, _ := fq.Pop()
if !lot2.Quantity().Equal(decimal.NewFromFloat(25)) {
t.Errorf("lot2: want quantity 25 but got %v", lot2.Quantity())
}
if !lot2.Price().Equal(decimal.NewFromFloat(40)) {
t.Errorf("lot2: want price 40 but got %v", lot2.Price())
}
}
func TestFillerQueue_AdjustForSplit_NilReceiver(t *testing.T) {
var fq *FillerQueue
fq.AdjustForSplit(decimal.NewFromFloat(5)) // must not panic
}

View File

@@ -24,6 +24,7 @@ func (d Kind) String() string {
} }
// Is returns true when k equals o // Is returns true when k equals o
func (k Kind) Is(o Kind) bool { func (k Kind) Is(o any) bool {
return k == o other, ok := o.(Kind)
return ok && k == other
} }

View File

@@ -168,6 +168,9 @@ func processRecord(ctx context.Context, q *FillerQueue, rec Record, sel Selector
} }
} }
case KindSplit:
q.AdjustForSplit(rec.Quantity())
default: default:
return fmt.Errorf("unknown side: %v", rec.Kind()) return fmt.Errorf("unknown side: %v", rec.Kind())
} }

View File

@@ -18,7 +18,7 @@ import (
type Record struct { type Record struct {
symbol string symbol string
timestamp time.Time timestamp time.Time
side internal.Kind kind internal.Kind
quantity decimal.Decimal quantity decimal.Decimal
price decimal.Decimal price decimal.Decimal
fees decimal.Decimal fees decimal.Decimal
@@ -45,7 +45,7 @@ func (r Record) AssetCountry() int64 {
} }
func (r Record) Kind() internal.Kind { func (r Record) Kind() internal.Kind {
return r.side return r.kind
} }
func (r Record) Quantity() decimal.Decimal { func (r Record) Quantity() decimal.Decimal {
@@ -81,34 +81,26 @@ func NewRecordReader(r io.Reader, f *internal.OpenFIGI) *RecordReader {
} }
const ( const (
MarketBuy = "market buy" MarketBuy = "market buy"
MarketSell = "market sell" MarketSell = "market sell"
LimitBuy = "limit buy" LimitBuy = "limit buy"
LimitSell = "limit sell" LimitSell = "limit sell"
StockSplitOpen = "Stock split open" StockSplitOpen = "stock split open"
StockSplitClose = "Stock split close" StockSplitClose = "stock split close"
StokDistribution = "stock distribution"
) )
func (rr RecordReader) ReadRecord(ctx context.Context) (internal.Record, error) { func (rr RecordReader) ReadRecord(ctx context.Context) (internal.Record, error) {
var splitRec *splitRecord
for { for {
raw, err := rr.reader.Read() raw, err := rr.reader.Read()
if err != nil { if err != nil {
return Record{}, fmt.Errorf("read record: %w", err) return Record{}, fmt.Errorf("read record: %w", err)
} }
var side internal.Kind if strings.ToLower(raw[0]) == "action" {
switch strings.ToLower(raw[0]) {
case MarketBuy, LimitBuy:
side = internal.KindBuy
case MarketSell, LimitSell:
side = internal.KindSell
case StockSplitOpen, StockSplitClose:
// TODO: emit a special event that triggers a readjustment of unsold stock
continue continue
case "action": // TODO: this is the header, there's probably a better way to handle this
continue
default:
return Record{}, fmt.Errorf("parse record type: %s", raw[0])
} }
qant, err := parseDecimal(raw[6]) qant, err := parseDecimal(raw[6])
@@ -141,9 +133,50 @@ func (rr RecordReader) ReadRecord(ctx context.Context) (internal.Record, error)
return Record{}, fmt.Errorf("parse record french transaction tax: %w", err) return Record{}, fmt.Errorf("parse record french transaction tax: %w", err)
} }
var kind internal.Kind
switch strings.ToLower(raw[0]) {
case MarketBuy, LimitBuy:
kind = internal.KindBuy
case MarketSell, LimitSell:
kind = internal.KindSell
case StockSplitOpen:
if splitRec != nil {
return nil, fmt.Errorf("split already open")
}
splitRec = &splitRecord{
Record: Record{
symbol: raw[2],
kind: internal.KindSplit,
quantity: qant,
price: price,
fees: conversionFee,
taxes: stampDutyTax.Add(frenchTxTax),
timestamp: ts,
natureGetter: figiNatureGetter(ctx, rr.figi, raw[2]),
},
}
continue
case StockSplitClose:
if splitRec == nil {
return nil, fmt.Errorf("missing split open")
}
splitRec.ratio = splitRec.Record.Quantity().Div(qant)
return splitRec, nil
case StokDistribution:
slog.Warn("Found stock distribution but can't handle it")
continue
default:
return Record{}, fmt.Errorf("parse record type: %s", raw[0])
}
return Record{ return Record{
symbol: raw[2], symbol: raw[2],
side: side, kind: kind,
quantity: qant, quantity: qant,
price: price, price: price,
fees: conversionFee, fees: conversionFee,
@@ -190,3 +223,13 @@ func parseOptionalDecimal(s string) (decimal.Decimal, error) {
return parseDecimal(s) return parseDecimal(s)
} }
type splitRecord struct {
Record
ratio decimal.Decimal
}
func (sr splitRecord) Quantity() decimal.Decimal {
return sr.ratio
}

View File

@@ -30,7 +30,7 @@ func TestRecordReader_ReadRecord(t *testing.T) {
r: bytes.NewBufferString(`Market buy,2025-07-03 10:44:29,XX1234567890,ABXY,"Aspargus Broccoli",EOF987654321,2.4387014200,7.3690000000,USD,1.17995999,,"EUR",15.25,"EUR",0.25,"EUR",0.02,"EUR",,`), r: bytes.NewBufferString(`Market buy,2025-07-03 10:44:29,XX1234567890,ABXY,"Aspargus Broccoli",EOF987654321,2.4387014200,7.3690000000,USD,1.17995999,,"EUR",15.25,"EUR",0.25,"EUR",0.02,"EUR",,`),
want: Record{ want: Record{
symbol: "XX1234567890", symbol: "XX1234567890",
side: internal.KindBuy, kind: internal.KindBuy,
quantity: ShouldParseDecimal(t, "2.4387014200"), quantity: ShouldParseDecimal(t, "2.4387014200"),
price: ShouldParseDecimal(t, "7.3690000000"), price: ShouldParseDecimal(t, "7.3690000000"),
timestamp: time.Date(2025, 7, 3, 10, 44, 29, 0, time.UTC), timestamp: time.Date(2025, 7, 3, 10, 44, 29, 0, time.UTC),
@@ -44,7 +44,7 @@ func TestRecordReader_ReadRecord(t *testing.T) {
r: bytes.NewBufferString(`Market sell,2025-08-04 11:45:30,XX1234567890,ABXY,"Aspargus Broccoli",EOF987654321,2.4387014200,7.9999999999,USD,1.17995999,,"EUR",15.25,"EUR",,,0.02,"EUR",0.1,"EUR"`), r: bytes.NewBufferString(`Market sell,2025-08-04 11:45:30,XX1234567890,ABXY,"Aspargus Broccoli",EOF987654321,2.4387014200,7.9999999999,USD,1.17995999,,"EUR",15.25,"EUR",,,0.02,"EUR",0.1,"EUR"`),
want: Record{ want: Record{
symbol: "XX1234567890", symbol: "XX1234567890",
side: internal.KindSell, kind: internal.KindSell,
quantity: ShouldParseDecimal(t, "2.4387014200"), quantity: ShouldParseDecimal(t, "2.4387014200"),
price: ShouldParseDecimal(t, "7.9999999999"), price: ShouldParseDecimal(t, "7.9999999999"),
timestamp: time.Date(2025, 8, 4, 11, 45, 30, 0, time.UTC), timestamp: time.Date(2025, 8, 4, 11, 45, 30, 0, time.UTC),
@@ -123,8 +123,8 @@ func TestRecordReader_ReadRecord(t *testing.T) {
t.Fatalf("want symbol %v but got %v", tt.want.symbol, got.Symbol()) t.Fatalf("want symbol %v but got %v", tt.want.symbol, got.Symbol())
} }
if got.Kind() != tt.want.side { if got.Kind() != tt.want.kind {
t.Fatalf("want side %v but got %v", tt.want.side, got.Kind()) t.Fatalf("want side %v but got %v", tt.want.kind, got.Kind())
} }
if got.Price().Cmp(tt.want.price) != 0 { if got.Price().Cmp(tt.want.price) != 0 {
@@ -154,6 +154,70 @@ func TestRecordReader_ReadRecord(t *testing.T) {
} }
} }
func TestRecordReader_ReadRecord_Split(t *testing.T) {
// open row has the NEW (post-split) position: more shares at lower price
// close row has the OLD (pre-split) position: fewer shares at higher price
// ratio = openQty / closeQty = 0.5 / 0.1 = 5 (a 5:1 split)
splitOpen := `Stock split open,2025-06-03 05:34:16,XX1234567890,ABXY,"Aspargus Broccoli",EOF111111111,0.5000000000,20.0000000000,EUR,1.00000000,,,10.00,"EUR",,,,,,`
splitClose := `Stock split close,2025-06-03 05:34:16,XX1234567890,ABXY,"Aspargus Broccoli",EOF222222222,0.1000000000,100.0000000000,EUR,1.00000000,0.00,"EUR",10.00,"EUR",,,,,,`
t.Run("well-formed split pair returns split record with correct ratio", func(t *testing.T) {
rr := NewRecordReader(
bytes.NewBufferString(splitOpen+"\n"+splitClose),
NewFigiClientSecurityTypeStub(t, "Common Stock"),
)
got, err := rr.ReadRecord(t.Context())
if err != nil {
t.Fatalf("ReadRecord() failed: %v", err)
}
if got.Kind() != internal.KindSplit {
t.Errorf("want kind %v but got %v", internal.KindSplit, got.Kind())
}
if got.Symbol() != "NO0013536151" {
t.Errorf("want symbol NO0013536151 but got %v", got.Symbol())
}
wantTimestamp := time.Date(2025, 6, 3, 5, 34, 16, 0, time.UTC)
if !got.Timestamp().Equal(wantTimestamp) {
t.Errorf("want timestamp %v but got %v", wantTimestamp, got.Timestamp())
}
// ratio = openQty / closeQty = 0.1245045 / 0.0249009 ≈ 5
openQty := ShouldParseDecimal(t, "0.1245045000")
closeQty := ShouldParseDecimal(t, "0.0249009000")
wantRatio := openQty.Div(closeQty)
if !got.Quantity().Equal(wantRatio) {
t.Errorf("want ratio %v but got %v", wantRatio, got.Quantity())
}
})
t.Run("close without prior open errors", func(t *testing.T) {
rr := NewRecordReader(
bytes.NewBufferString(splitClose),
NewFigiClientSecurityTypeStub(t, "Common Stock"),
)
_, err := rr.ReadRecord(t.Context())
if err == nil {
t.Fatal("expected error but got none")
}
})
t.Run("two opens without close errors", func(t *testing.T) {
rr := NewRecordReader(
bytes.NewBufferString(splitOpen+"\n"+splitOpen),
NewFigiClientSecurityTypeStub(t, "Common Stock"),
)
_, err := rr.ReadRecord(t.Context())
if err == nil {
t.Fatal("expected error but got none")
}
})
}
func Test_figiNatureGetter(t *testing.T) { func Test_figiNatureGetter(t *testing.T) {
tests := []struct { tests := []struct {
name string // description of this test case name string // description of this test case