diff --git a/server/events/sse.go b/server/events/sse.go index a387f6347..dddff24b4 100644 --- a/server/events/sse.go +++ b/server/events/sse.go @@ -95,26 +95,35 @@ func (b *broker) prepareMessage(ctx context.Context, event Event) message { var errWriteTimeOut = errors.New("write timeout") -// writeEvent Write to the ResponseWriter, Server Sent Events compatible, and sends it -// right away, by flushing the writer (if it is a Flusher). It waits for the message to be flushed -// or times out after the specified timeout +// writeEvent writes a message to the given io.Writer, formatted as a Server-Sent Event. +// If the writer is an http.Flusher, it flushes the data immediately instead of buffering it. +// The function waits for the message to be written or times out after the specified timeout. func writeEvent(w io.Writer, event message, timeout time.Duration) (err error) { + // Create a context with a timeout based on the event's sender context. + ctx, cancel := context.WithTimeout(event.senderCtx, timeout) + defer cancel() + + // Create a channel to signal the completion of writing. complete := make(chan struct{}, 1) - flusher, ok := w.(http.Flusher) - if ok { - go func() { - _, err = fmt.Fprintf(w, "id: %d\nevent: %s\ndata: %s\n\n", event.id, event.event, event.data) - // Flush the data immediately instead of buffering it for later. + + // Start a goroutine to write the event and optionally flush the writer. + go func() { + _, err = fmt.Fprintf(w, "id: %d\nevent: %s\ndata: %s\n\n", event.id, event.event, event.data) + + // If the writer is an http.Flusher, flush the data immediately. + if flusher, ok := w.(http.Flusher); ok { flusher.Flush() - complete <- struct{}{} - }() - } else { + } + + // Signal that writing is complete. complete <- struct{}{} - } + }() + + // Wait for either the write completion or the context to time out. select { case <-complete: - return - case <-time.After(timeout): + return err + case <-ctx.Done(): return errWriteTimeOut } } diff --git a/server/events/sse_test.go b/server/events/sse_test.go index e6a44ca15..6ff8426b1 100644 --- a/server/events/sse_test.go +++ b/server/events/sse_test.go @@ -1,7 +1,12 @@ package events import ( + "bytes" "context" + "fmt" + "io" + "sync/atomic" + "time" "github.com/navidrome/navidrome/model/request" . "github.com/onsi/ginkgo/v2" @@ -58,4 +63,126 @@ var _ = Describe("Broker", func() { }) }) }) + + Describe("writeEvent", func() { + var ( + timeout time.Duration + buffer *bytes.Buffer + event message + senderCtx context.Context + cancel context.CancelFunc + ) + + BeforeEach(func() { + buffer = &bytes.Buffer{} + senderCtx, cancel = context.WithCancel(context.Background()) + DeferCleanup(cancel) + }) + + Context("with an HTTP flusher", func() { + var flusher *fakeFlusher + + BeforeEach(func() { + flusher = &fakeFlusher{Writer: buffer} + event = message{ + senderCtx: senderCtx, + id: 1, + event: "test", + data: "testdata", + } + }) + + Context("when the write completes before the timeout", func() { + BeforeEach(func() { + timeout = 1 * time.Second + }) + It("should successfully write the event", func() { + err := writeEvent(flusher, event, timeout) + Expect(err).NotTo(HaveOccurred()) + Expect(buffer.String()).To(Equal(fmt.Sprintf("id: %d\nevent: %s\ndata: %s\n\n", event.id, event.event, event.data))) + Expect(flusher.flushed.Load()).To(BeTrue()) + }) + }) + + Context("when the write does not complete before the timeout", func() { + BeforeEach(func() { + timeout = 1 * time.Millisecond + flusher.delay = 10 * time.Millisecond + }) + + It("should return an errWriteTimeOut error", func() { + err := writeEvent(flusher, event, timeout) + Expect(err).To(MatchError(errWriteTimeOut)) + Expect(flusher.flushed.Load()).To(BeFalse()) + }) + }) + + Context("without an HTTP flusher", func() { + var writer *fakeWriter + + BeforeEach(func() { + writer = &fakeWriter{Writer: buffer} + event = message{ + senderCtx: senderCtx, + id: 1, + event: "test", + data: "testdata", + } + }) + + Context("when the write completes before the timeout", func() { + BeforeEach(func() { + timeout = 1 * time.Second + }) + + It("should successfully write the event", func() { + err := writeEvent(writer, event, timeout) + Expect(err).NotTo(HaveOccurred()) + Eventually(writer.done.Load).Should(BeTrue()) + Expect(buffer.String()).To(Equal(fmt.Sprintf("id: %d\nevent: %s\ndata: %s\n\n", event.id, event.event, event.data))) + }) + }) + + Context("when the write does not complete before the timeout", func() { + BeforeEach(func() { + timeout = 1 * time.Millisecond + writer.delay = 10 * time.Millisecond + }) + + It("should return an errWriteTimeOut error", func() { + err := writeEvent(writer, event, timeout) + Expect(err).To(MatchError(errWriteTimeOut)) + Expect(writer.done.Load()).To(BeFalse()) + }) + }) + }) + }) + }) }) + +type fakeWriter struct { + io.Writer + delay time.Duration + done atomic.Bool +} + +func (f *fakeWriter) Write(p []byte) (n int, err error) { + time.Sleep(f.delay) + f.done.Store(true) + return f.Writer.Write(p) +} + +type fakeFlusher struct { + io.Writer + delay time.Duration + flushed atomic.Bool +} + +func (f *fakeFlusher) Write(p []byte) (n int, err error) { + time.Sleep(f.delay) + return f.Writer.Write(p) +} + +func (f *fakeFlusher) Flush() { + f.flushed.Store(true) +}