From a394dab041f0ae09de10f1ac5d2f8488aa474fdb Mon Sep 17 00:00:00 2001 From: Natercio Moniz Date: Mon, 9 Sep 2024 11:22:00 +0100 Subject: [PATCH] major refactor to accommodate feed --- async.go | 125 +++++++++-------------------------------- async_bench_test.go | 28 +--------- async_test.go | 133 +++++++++++++------------------------------- feed.go | 37 ++++++++++++ feed_test.go | 62 +++++++++++++++++++++ options.go | 24 ++++++++ sync.go | 40 ++++++++++++- wrappers.go | 79 ++++++++++++++++++++++++++ wrappers_test.go | 55 ++++++++++++++++++ 9 files changed, 360 insertions(+), 223 deletions(-) create mode 100644 feed.go create mode 100644 feed_test.go create mode 100644 options.go create mode 100644 wrappers_test.go diff --git a/async.go b/async.go index d1d2c51..c1fb873 100644 --- a/async.go +++ b/async.go @@ -1,9 +1,7 @@ package gubgub import ( - "context" "fmt" - "iter" "sync" ) @@ -15,10 +13,11 @@ import ( // that some might never actually receive any message. Subscriber registration order is also not // guaranteed. type AsyncTopic[T any] struct { - options AsyncTopicOptions + options TopicOptions - mu sync.RWMutex - closed bool + mu sync.RWMutex + closing bool + closed chan struct{} publishCh chan T subscribeCh chan Subscriber[T] @@ -26,8 +25,8 @@ type AsyncTopic[T any] struct { // NewAsyncTopic creates an AsyncTopic that will be closed when the given context is cancelled. // After closed calls to Publish or Subscribe will return an error. -func NewAsyncTopic[T any](ctx context.Context, opts ...AsyncTopicOption) *AsyncTopic[T] { - options := AsyncTopicOptions{ +func NewAsyncTopic[T any](opts ...TopicOption) *AsyncTopic[T] { + options := TopicOptions{ onClose: func() {}, // Called after the Topic is closed and all messages have been delivered. onSubscribe: func(count int) {}, // Called everytime a new subscriber is added } @@ -38,34 +37,46 @@ func NewAsyncTopic[T any](ctx context.Context, opts ...AsyncTopicOption) *AsyncT t := AsyncTopic[T]{ options: options, + closed: make(chan struct{}), publishCh: make(chan T, 1), subscribeCh: make(chan Subscriber[T], 1), } - go t.closer(ctx) go t.run() return &t } -func (t *AsyncTopic[T]) closer(ctx context.Context) { - <-ctx.Done() +// Close terminates background go routines and prevents further publishing and subscribing. All +// published messages are garanteed to be delivered once Close returns. This is idempotent. +func (t *AsyncTopic[T]) Close() { + t.mu.RLock() + closing := t.closing + t.mu.RUnlock() + + if closing { + // It's either already closed or it's closing. + return + } t.mu.Lock() - t.closed = true // no more subscribing or publishing + t.closing = true // no more subscribing or publishing t.mu.Unlock() close(t.publishCh) close(t.subscribeCh) + + <-t.closed } func (t *AsyncTopic[T]) run() { + defer close(t.closed) defer t.options.onClose() var subscribers []Subscriber[T] defer func() { - // There is only one way to get here: the topic is now closed! + // There is only one way to get here: the topic is now closing! // Because both `subscribeCh` and `publishCh` channels are closed when the topic is closed // this will always eventually return. // This will deliver any potential queued message thus fulfilling the message delivery @@ -99,7 +110,7 @@ func (t *AsyncTopic[T]) run() { func (t *AsyncTopic[T]) Publish(msg T) error { t.mu.RLock() - if t.closed { + if t.closing { t.mu.RUnlock() return fmt.Errorf("async topic publish: %w", ErrTopicClosed) } @@ -116,7 +127,7 @@ func (t *AsyncTopic[T]) Publish(msg T) error { func (t *AsyncTopic[T]) Subscribe(fn Subscriber[T]) error { t.mu.RLock() - if t.closed { + if t.closing { t.mu.RUnlock() return fmt.Errorf("async topic subscribe: %w", ErrTopicClosed) } @@ -128,89 +139,3 @@ func (t *AsyncTopic[T]) Subscribe(fn Subscriber[T]) error { return nil } - -// Feed allows the usage of for/range to consume future published messages. -// The supporting subscriber will eventually be discarded after you exit the for loop. -func (t *AsyncTopic[T]) Feed() iter.Seq[T] { - feed := make(chan T, 1) // closed by the middleman go routine - messages := make(chan T, 1) // closed by the subscriber - yieldReady := make(chan struct{}) // closed by the iterator - unsubscribe := make(chan struct{}) // closed by the iterator - - t.Subscribe(func(msg T) bool { - select { - case messages <- msg: - return true - case <-unsubscribe: - close(messages) - return false - } - }) - - go func() { - defer close(feed) - - q := make([]T, 0, 1) - waiting := false - - for { - select { - case m, more := <-messages: - if !more { - return - } - - if waiting { - waiting = false - feed <- m - } else { - q = append(q, m) - } - - case _, more := <-yieldReady: - if !more { - return - } - - if len(q) > 0 { - waiting = false - feed <- q[0] - q = q[1:] - } else { - waiting = true - } - } - } - }() - - return func(yield func(T) bool) { - defer close(unsubscribe) - defer close(yieldReady) - - for { - yieldReady <- struct{}{} - if !yield(<-feed) { - return - } - } - } -} - -type AsyncTopicOptions struct { - onClose func() - onSubscribe func(count int) -} - -type AsyncTopicOption func(*AsyncTopicOptions) - -func WithOnClose(fn func()) AsyncTopicOption { - return func(opts *AsyncTopicOptions) { - opts.onClose = fn - } -} - -func WithOnSubscribe(fn func(count int)) AsyncTopicOption { - return func(opts *AsyncTopicOptions) { - opts.onSubscribe = fn - } -} diff --git a/async_bench_test.go b/async_bench_test.go index 5860408..a8defe2 100644 --- a/async_bench_test.go +++ b/async_bench_test.go @@ -1,7 +1,6 @@ package gubgub import ( - "context" "testing" "github.com/stretchr/testify/require" @@ -10,26 +9,8 @@ import ( func BenchmarkAsyncTopic_Publish(b *testing.B) { for _, tc := range publishCases { b.Run(tc.Name, func(b *testing.B) { - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - subscribersReady := make(chan struct{}, 1) - defer close(subscribersReady) - - topicClosed := make(chan struct{}, 1) - defer close(topicClosed) - - topic := NewAsyncTopic[int](ctx, - WithOnSubscribe(func(count int) { - if count == tc.Count { - subscribersReady <- struct{}{} - } - }), - WithOnClose(func() { - topicClosed <- struct{}{} - }), - ) + onSubscribe, subscribersReady := withNotifyOnNthSubscriber(b, int64(tc.Count)) + topic := newTestAsyncTopic[int](b, onSubscribe) for range tc.Count { require.NoError(b, topic.Subscribe(tc.Subscriber)) @@ -45,10 +26,7 @@ func BenchmarkAsyncTopic_Publish(b *testing.B) { b.StopTimer() - cancel() - - // This just helps leaving as few running Go routines as possible when the next round starts - <-topicClosed + topic.Close() }) } diff --git a/async_test.go b/async_test.go index 52fba99..5da75d5 100644 --- a/async_test.go +++ b/async_test.go @@ -1,8 +1,7 @@ package gubgub import ( - "context" - "sync" + "sync/atomic" "testing" "time" @@ -13,12 +12,8 @@ import ( func TestAsyncTopic_SinglePublisherSingleSubscriber(t *testing.T) { const msgCount = 10 - subscriberReady := make(chan struct{}, 1) - defer close(subscriberReady) - - topic := NewAsyncTopic[int](context.Background(), WithOnSubscribe(func(count int) { - subscriberReady <- struct{}{} - })) + onSubscribe, subscriberReady := withNotifyOnNthSubscriber(t, 1) + topic := newTestAsyncTopic[int](t, onSubscribe) feedback := make(chan struct{}, msgCount) defer close(feedback) @@ -56,17 +51,8 @@ func TestAsyncTopic_MultiPublishersMultiSubscribers(t *testing.T) { msgCount = pubCount * 100 // total messages to publish (delivered to EACH subscriber) ) - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - - subscribersReady := make(chan struct{}, 1) - defer close(subscribersReady) - - topic := NewAsyncTopic[int](ctx, WithOnSubscribe(func(count int) { - if count == subCount { - subscribersReady <- struct{}{} - } - })) + onSubscribe, subscribersReady := withNotifyOnNthSubscriber(t, subCount) + topic := newTestAsyncTopic[int](t, onSubscribe) expFeedbackCount := msgCount * subCount feedback := make(chan int, expFeedbackCount) @@ -111,15 +97,12 @@ func TestAsyncTopic_MultiPublishersMultiSubscribers(t *testing.T) { } func TestAsyncTopic_WithOnClose(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - feedback := make(chan struct{}, 1) defer close(feedback) - _ = NewAsyncTopic[int](ctx, WithOnClose(func() { feedback <- struct{}{} })) + topic := NewAsyncTopic[int](WithOnClose(func() { feedback <- struct{}{} })) - cancel() + topic.Close() select { case <-feedback: @@ -133,13 +116,10 @@ func TestAsyncTopic_WithOnClose(t *testing.T) { func TestAsyncTopic_WithOnSubscribe(t *testing.T) { const totalSub = 10 - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - feedback := make(chan int, totalSub) defer close(feedback) - topic := NewAsyncTopic[int](ctx, WithOnSubscribe(func(count int) { feedback <- count })) + topic := NewAsyncTopic[int](WithOnSubscribe(func(count int) { feedback <- count })) for range totalSub { topic.Subscribe(func(i int) bool { return true }) @@ -180,23 +160,12 @@ func TestAsyncTopic_ClosedTopicError(t *testing.T) { } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - feedback := make(chan struct{}, 1) defer close(feedback) - topic := NewAsyncTopic[int](ctx, WithOnClose(func() { feedback <- struct{}{} })) + topic := NewAsyncTopic[int]() - cancel() // this should close the topic, no more messages can be published - - select { - case <-feedback: - break - - case <-testTimer(t, time.Second).C: - t.Fatalf("expected feedback by now") - } + topic.Close() // this should close the topic, no more messages can be published tc.assertFn(topic) }) @@ -206,17 +175,8 @@ func TestAsyncTopic_ClosedTopicError(t *testing.T) { func TestAsyncTopic_AllPublishedBeforeClosedAreDeliveredAfterClosed(t *testing.T) { const msgCount = 10 - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - subscriberReady := make(chan struct{}, 1) - defer close(subscriberReady) - - topic := NewAsyncTopic[int](ctx, - WithOnSubscribe(func(count int) { - subscriberReady <- struct{}{} - }), - ) + onSubscribe, subscriberReady := withNotifyOnNthSubscriber(t, 1) + topic := newTestAsyncTopic[int](t, onSubscribe) feedback := make(chan int) // unbuffered will cause choke point for publishers defer close(feedback) @@ -233,7 +193,7 @@ func TestAsyncTopic_AllPublishedBeforeClosedAreDeliveredAfterClosed(t *testing.T require.NoError(t, topic.Publish(i)) } - cancel() + go topic.Close() values := make(map[int]struct{}, msgCount) timeout := testTimer(t, time.Second) @@ -249,48 +209,6 @@ func TestAsyncTopic_AllPublishedBeforeClosedAreDeliveredAfterClosed(t *testing.T } } -func TestAsyncTopic_Feed(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - subscriberReady := make(chan struct{}, 1) - defer close(subscriberReady) - - topic := NewAsyncTopic[int](ctx, - WithOnSubscribe(func(count int) { - subscriberReady <- struct{}{} - }), - ) - - wg := sync.WaitGroup{} - - wg.Add(1) - go func() { - defer wg.Done() - - seen := make(map[int]struct{}) - for i := range topic.Feed() { - seen[i] = struct{}{} - if len(seen) >= 9 { - return - } - } - }() - - <-subscriberReady - - wg.Add(1) - go func() { - defer wg.Done() - - for i := range 10 { - topic.Publish(i) - } - }() - - wg.Wait() -} - func testTimer(t testing.TB, d time.Duration) *time.Timer { t.Helper() @@ -301,3 +219,28 @@ func testTimer(t testing.TB, d time.Duration) *time.Timer { return timer } + +func newTestAsyncTopic[T any](t testing.TB, opts ...TopicOption) *AsyncTopic[T] { + t.Helper() + topic := NewAsyncTopic[T](opts...) + t.Cleanup(topic.Close) + return topic +} + +func withNotifyOnNthSubscriber(t testing.TB, n int64) (TopicOption, <-chan struct{}) { + t.Helper() + + notify := make(chan struct{}, 1) + t.Cleanup(func() { + close(notify) + }) + + var counter atomic.Int64 + + return WithOnSubscribe(func(count int) { + c := counter.Add(1) + if c == n { + notify <- struct{}{} + } + }), notify +} diff --git a/feed.go b/feed.go new file mode 100644 index 0000000..4165038 --- /dev/null +++ b/feed.go @@ -0,0 +1,37 @@ +package gubgub + +import "iter" + +// Feed allows the usage of for/range to consume future published messages. +// The supporting subscriber will eventually be discarded after you exit the for loop. +func Feed[T any](t Subscribable[T], buffered bool) iter.Seq[T] { + feed := make(chan T) // closed by the subscriber + unsubscribe := make(chan struct{}) // closed by the iterator + + subscriber := func(msg T) bool { + select { + case feed <- msg: + return true + case <-unsubscribe: + close(feed) + return false + } + } + + if buffered { + subscriber = Buffered(subscriber) + } + + t.Subscribe(subscriber) + + // Iterator + return func(yield func(T) bool) { + defer close(unsubscribe) + + for msg := range feed { + if !yield(msg) { + return + } + } + } +} diff --git a/feed_test.go b/feed_test.go new file mode 100644 index 0000000..6494a64 --- /dev/null +++ b/feed_test.go @@ -0,0 +1,62 @@ +package gubgub + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestFeed_Topics(t *testing.T) { + const msgCount = 10 + + subscriberReady := make(chan struct{}, 1) + defer close(subscriberReady) + + onSubscribe := WithOnSubscribe(func(count int) { + subscriberReady <- struct{}{} + }) + + testCases := []struct { + name string + topic Topic[int] + }{ + { + name: "sync topic", + topic: NewSyncTopic[int](onSubscribe), + }, + { + name: "async topic", + topic: NewAsyncTopic[int](onSubscribe), + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + feedback := make(chan int) + go func() { + for i := range Feed(tc.topic, false) { + feedback <- i + } + }() + + go func() { + <-subscriberReady + for i := range msgCount { + require.NoError(t, tc.topic.Publish(i)) + } + }() + + var counter int + timeout := testTimer(t, time.Second) + for counter < msgCount { + select { + case <-feedback: + counter++ + case <-timeout.C: + t.Fatalf("expected %d feedback values by now but only got %d", msgCount, counter) + } + } + }) + } + +} diff --git a/options.go b/options.go new file mode 100644 index 0000000..eb8201e --- /dev/null +++ b/options.go @@ -0,0 +1,24 @@ +package gubgub + +// TopicOptions holds common options for topics. +type TopicOptions struct { + // onClose is called after the Topic is closed and all messages have been delivered. + onClose func() + + // onSubscribe is called after a new subscriber is regitered. + onSubscribe func(count int) +} + +type TopicOption func(*TopicOptions) + +func WithOnClose(fn func()) TopicOption { + return func(opts *TopicOptions) { + opts.onClose = fn + } +} + +func WithOnSubscribe(fn func(count int)) TopicOption { + return func(opts *TopicOptions) { + opts.onSubscribe = fn + } +} diff --git a/sync.go b/sync.go index 45dce80..6dd426d 100644 --- a/sync.go +++ b/sync.go @@ -1,21 +1,50 @@ package gubgub -import "sync" +import ( + "fmt" + "sync" + "sync/atomic" +) // SyncTopic is the simplest and most naive topic. It allows any message T to be broadcast to // subscribers. Publishing and Subscribing happens synchronously (block). type SyncTopic[T any] struct { + options TopicOptions + + closed atomic.Bool + mu sync.Mutex subscribers []Subscriber[T] } // NewSyncTopic creates a zero SyncTopic and return a pointer to it. -func NewSyncTopic[T any]() *SyncTopic[T] { - return &SyncTopic[T]{} +func NewSyncTopic[T any](opts ...TopicOption) *SyncTopic[T] { + options := TopicOptions{ + onClose: func() {}, + onSubscribe: func(count int) {}, + } + + for _, opt := range opts { + opt(&options) + } + + return &SyncTopic[T]{ + options: options, + } +} + +// Close will cause future Publish and Subscribe calls to return an error. +func (t *SyncTopic[T]) Close() { + t.closed.Store(true) + t.options.onClose() } // Publish broadcasts a message to all subscribers. func (t *SyncTopic[T]) Publish(msg T) error { + if t.closed.Load() { + return fmt.Errorf("sync topic publish: %w", ErrTopicClosed) + } + t.mu.Lock() defer t.mu.Unlock() @@ -26,10 +55,15 @@ func (t *SyncTopic[T]) Publish(msg T) error { // Subscribe adds a Subscriber func that will consume future published messages. func (t *SyncTopic[T]) Subscribe(fn Subscriber[T]) error { + if t.closed.Load() { + return fmt.Errorf("sync topic subscribe: %w", ErrTopicClosed) + } + t.mu.Lock() defer t.mu.Unlock() t.subscribers = append(t.subscribers, fn) + t.options.onSubscribe(len(t.subscribers)) return nil } diff --git a/wrappers.go b/wrappers.go index 944e119..df71a4a 100644 --- a/wrappers.go +++ b/wrappers.go @@ -23,3 +23,82 @@ func Once[T any](fn func(T)) Subscriber[T] { func NoOp[T any]() Subscriber[T] { return func(_ T) bool { return true } } + +// Buffered returns a subscriber that buffers messages if they can't be delivered immediately. +// There is no artificial limit to how many items can be buffered. This is bounded only by +// available memory. +// This is useful if message publishing is surge prone and message processing is slow or +// unpredictable (for example: subscriber makes network request). +// Message average processing rate must still be higher than the average message publishing rate +// otherwise it will eventually lead to memory issues. You will need to find a better strategy to +// deal with such scenario. +func Buffered[T any](subscriber Subscriber[T]) Subscriber[T] { + unsubscribe := make(chan struct{}) // closed by the worker + ready := make(chan struct{}) // closed by the worker + messages := make(chan T) // closed by the forwarder + work := make(chan T) // closed by the middleman + + // Worker calls the actual subscriber. It notifies the middleman that it's ready for the next + // message via the ready channel and then reads from the work channel. + go func() { + for w := range work { + if !subscriber(w) { + close(unsubscribe) + close(ready) + return + } + ready <- struct{}{} + } + }() + + // Middleman that handles buffering. When the worker notifies that it is ready for the next + // message it will check if there is buffered messages and push the next one immediately or + // else push it when the next message arrives. + go func() { + defer close(work) + + idling := true // so that the first message can go straight to the consumer + + q := make([]T, 0, 1) + + for { + select { + case msg, more := <-messages: + if !more { + return + } + + if idling { + idling = false + work <- msg + } else { + q = append(q, msg) + } + + case _, more := <-ready: + if !more { + return + } + + if len(q) > 0 { + work <- q[0] + q = q[1:] + } else { + idling = true + } + } + + } + }() + + // forwarder just sends messages to the middleman or quits. + return func(msg T) bool { + select { + case messages <- msg: + return true + case <-unsubscribe: + close(messages) + return false + } + } +} diff --git a/wrappers_test.go b/wrappers_test.go new file mode 100644 index 0000000..25ee413 --- /dev/null +++ b/wrappers_test.go @@ -0,0 +1,55 @@ +package gubgub + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestBuffered_Once(t *testing.T) { + feedback := make(chan int, 1) + s := Buffered(Once(func(i int) { + feedback <- i // buffered channel means no blocking + })) + + assert.True(t, s(1234)) + + timeout := testTimer(t, time.Second) + + select { + case i := <-feedback: + assert.Equal(t, 1234, i) + + case <-timeout.C: + t.Fatalf("expected feedback value by now") + } + + assert.False(t, s(4321)) +} + +func TestBuffered_Forever(t *testing.T) { + const msgCount = 100 + + feedback := make(chan int) + s := Buffered(Forever(func(i int) { + feedback <- i // unbuffered channel creates choke point (blocks) to force buffering + })) + + for i := range msgCount { + assert.True(t, s(i)) + } + + timeout := testTimer(t, time.Second) + + var count int + for count < msgCount { + select { + case <-feedback: + count++ + + case <-timeout.C: + t.Fatalf("expected %d feedback values by now", msgCount) + } + } +}