ajhahn.de
← Flash
Flash 231 lines
// transport — LSP base-protocol framing: Content-Length headers over a
// byte stream.
//
// The Language Server Protocol wraps every JSON-RPC message in an
// HTTP-style header block — `Content-Length: N\r\n`, optionally other
// headers, then `\r\n` and exactly N body bytes. This module is the pure
// half of the server's stdio loop: `scan` recognizes one complete frame
// in a byte buffer, `frame` wraps a body for writing, and `Decoder`
// accumulates arbitrary read chunks and hands out complete bodies in
// order. Nothing here touches a file descriptor — the tests feed byte
// buffers, and the server owns the actual stdin/stdout plumbing.
//
// Per the spec, header names compare case-insensitively, and every header
// other than Content-Length (such as `Content-Type`) is skipped. A complete
// header block without a parseable Content-Length is `error.Malformed` —
// with the framing gone, the stream can never resynchronize, so the
// server's only move is to report and exit.

use std
use core

// Every failure the transport reports: `Malformed` from `scan` (and so from
// `Decoder.next`) when a complete header block cannot be framed, and
// `OutOfMemory` when `frame` or `Decoder.feed` cannot allocate.
pub const Error = error{ Malformed, OutOfMemory }

// One recognized frame, expressed as offsets into the buffer given to
// `scan`: the body spans `[body_start, body_start+body_len)` and the whole
// frame — headers included — occupies the first `total` bytes of that
// buffer, so `total` is how far a reader advances past this frame.
pub const Scan = struct {
    // Offset of the first body byte, just past the `\r\n\r\n` ending the headers.
    body_start usize,
    // Body size in bytes, as given by the Content-Length header.
    body_len usize,
    // body_start + body_len: the full frame length to consume.
    total usize,
}

// Recognize one complete frame at the start of `buf`.
//
// Returns null while the frame is still incomplete — no `\r\n\r\n` header
// terminator yet, or fewer than Content-Length body bytes buffered — so
// the caller should append more bytes and rescan.
//
// Returns error.Malformed once the header block is complete but a header
// line has no ':', or Content-Length is absent or does not parse as a
// base-10 usize. Header names match case-insensitively, headers other
// than Content-Length are skipped, and a repeated Content-Length takes
// the last value seen.
pub fn scan(buf []u8) Error!?Scan {
    term := core.mem.indexOf(u8, buf, "\r\n\r\n") orelse return null
    var body_len ?usize = null
    var pos usize = 0
    while pos < term {
        // A header line ends at the next CRLF; the final one ends at `term`.
        var line_end = term
        if core.mem.indexOf(u8, buf[pos..term], "\r\n") |rel| {
            line_end = pos + rel
        }
        line := buf[pos..line_end]
        colon := core.mem.indexOfScalar(u8, line, ':') orelse return error.Malformed
        if asciiEqlNoCase(line[0..colon], "content-length") {
            v := trim(line[colon + 1 ..])
            body_len = core.fmt.parseInt(usize, v, 10) catch return error.Malformed
        }
        pos = line_end + 2
    }
    n := body_len orelse return error.Malformed
    body_start := term + 4
    // `n` comes straight from the peer, so `body_start + n` can overflow
    // usize. Compare against the bytes actually buffered after the headers
    // instead; `body_start <= buf.len` because the terminator was found.
    if buf.len - body_start < n {
        return null
    }
    return Scan{ .body_start = body_start, .body_len = n, .total = body_start + n }
}

// Build the wire form of `body`: a `Content-Length` header giving its byte
// count, the blank line, then the body verbatim. The result is a fresh
// allocation from `alloc` that the caller owns; the only failure is
// running out of memory.
pub fn frame(alloc std.mem.Allocator, body []u8) Error![]u8 {
    framed := try core.fmt.allocPrint(alloc, "Content-Length: {d}\r\n\r\n{s}", .{ body.len, body })
    return framed
}

// Decoder is the accumulating reader side of the transport: `feed` appends
// each chunk as it arrives, and `next` returns the body of the earliest
// complete frame. Returned bodies alias the decoder's own buffer and stay
// valid only until the following `feed` (which may move or reallocate it)
// or `deinit` — process or copy a message before feeding more bytes.
pub const Decoder = struct {
    // Backing allocation. `data[start..len]` holds bytes not yet consumed;
    // `data[0..start]` was already handed out by `next` and is dropped by
    // the next `feed`.
    data []mut u8,
    len usize,
    start usize,

    pub const empty Decoder = .{ .data = &.{}, .len = 0, .start = 0 }

    // Free the backing allocation and reset to the `empty` state.
    pub fn deinit(self *mut Decoder, alloc std.mem.Allocator) void {
        alloc.free(self.data)
        self.* = .empty
    }

    // Append `bytes` after the unread tail. Consumed bytes are squeezed out
    // first, so the allocation only has to hold unread data plus this chunk.
    // On OutOfMemory the decoder has been compacted but is otherwise intact.
    pub fn feed(self *mut Decoder, alloc std.mem.Allocator, bytes []u8) Error!void {
        self.compact()
        end := self.len + bytes.len
        try self.reserve(alloc, end)
        core.mem.copy(u8, self.data[self.len..end], bytes)
        self.len = end
    }

    // Slide the unread bytes down to offset 0 and forget the consumed prefix.
    // Source and destination may overlap, with the destination first.
    // NOTE(review): this relies on core.mem.copy copying front to back for
    // overlapping slices — confirm.
    fn compact(self *mut Decoder) void {
        if self.start != 0 {
            unread := self.len - self.start
            core.mem.copy(u8, self.data[0..unread], self.data[self.start..self.len])
            self.len = unread
            self.start = 0
        }
    }

    // Ensure the allocation holds at least `needed` bytes. When it must grow,
    // the new capacity is the largest of `needed`, double the old size, and 64.
    fn reserve(self *mut Decoder, alloc std.mem.Allocator, needed usize) Error!void {
        if needed > self.data.len {
            doubled := self.data.len * 2
            var cap usize = needed
            if cap < doubled {
                cap = doubled
            }
            if cap < 64 {
                cap = 64
            }
            bigger := try alloc.alloc(u8, cap)
            core.mem.copy(u8, bigger[0..self.len], self.data[0..self.len])
            alloc.free(self.data)
            self.data = bigger
        }
    }

    // Return the body of the earliest complete frame and mark that frame
    // consumed, or null when the unread bytes do not yet hold a whole frame.
    // error.Malformed from `scan` is passed through. Valid until the next `feed`.
    pub fn next(self *mut Decoder) Error!?[]u8 {
        unread := self.data[self.start..self.len]
        found := (try scan(unread)) orelse return null
        self.start += found.total
        return unread[found.body_start .. found.body_start + found.body_len]
    }
}

// Report whether `a` and `b` hold the same bytes once ASCII letters are
// folded to lowercase; every other byte must match exactly.
fn asciiEqlNoCase(a []u8, b []u8) bool {
    if a.len != b.len {
        return false
    }
    // Advance while the folded bytes agree; reaching the end means equal.
    var i usize = 0
    while i < a.len && asciiLower(a[i]) == asciiLower(b[i]) {
        i += 1
    }
    return i == a.len
}

// Map an ASCII uppercase letter to its lowercase form; return any other
// byte unchanged.
fn asciiLower(c u8) u8 {
    if c < 'A' || c > 'Z' {
        return c
    }
    return c - 'A' + 'a'
}

// Return `s` without leading or trailing spaces and tabs, as a subslice of
// the input (no copy).
fn trim(s []u8) []u8 {
    var first usize = 0
    var end usize = s.len
    while first < end {
        ch := s[first]
        if ch != ' ' && ch != '\t' {
            break
        }
        first += 1
    }
    while end > first {
        ch := s[end - 1]
        if ch != ' ' && ch != '\t' {
            break
        }
        end -= 1
    }
    return s[first..end]
}

test "scan recognizes a complete frame" {
    found := (try scan("Content-Length: 2\r\n\r\nhi")).?
    // 17 header bytes plus CRLF CRLF put the body at 21; 2 body bytes end it at 23.
    try std.testing.expectEqual(23, found.total)
    try std.testing.expectEqual(2, found.body_len)
    try std.testing.expectEqual(21, found.body_start)
}

test "scan waits for the header terminator and the full body" {
    // Cut points: nothing at all, headers with no terminator, headers with
    // no body, and one of the two promised body bytes.
    partial := "Content-Length: 2\r\n\r\nh"
    try std.testing.expect((try scan(partial[0..0])) == null)
    try std.testing.expect((try scan(partial[0..17])) == null)
    try std.testing.expect((try scan(partial[0..21])) == null)
    try std.testing.expect((try scan(partial)) == null)
}

test "scan skips unknown headers and compares names case-insensitively" {
    // Mixed-case Content-Length first, then a Content-Type the scanner ignores.
    input := "content-LENGTH: 4\r\nContent-Type: application/vscode-jsonrpc; charset=utf-8\r\n\r\nbody"
    found := (try scan(input)).?
    try std.testing.expectEqual(input.len, found.total)
    try std.testing.expectEqual(4, found.body_len)
}

test "scan rejects a header block it cannot frame by" {
    // A header line with no ':' separator.
    try std.testing.expectError(error.Malformed, scan("no colon here\r\n\r\nxx"))
    // Content-Length present but not a number.
    try std.testing.expectError(error.Malformed, scan("Content-Length: ten\r\n\r\nxx"))
    // Header block complete, Content-Length never given.
    try std.testing.expectError(error.Malformed, scan("Content-Type: text\r\n\r\nxx"))
}

test "frame writes the exact bytes scan reads back" {
    var arena = core.arena.ArenaAllocator.init(std.testing.allocator)
    defer arena.deinit()
    payload := "{\"id\":1}"
    framed := try frame(arena.allocator(), payload)
    try std.testing.expect(core.mem.eql(u8, framed, "Content-Length: 8\r\n\r\n{\"id\":1}"))
    // Round trip: scanning the framed bytes must recover the original payload.
    found := (try scan(framed)).?
    body := framed[found.body_start .. found.body_start + found.body_len]
    try std.testing.expect(core.mem.eql(u8, body, payload))
}

test "a decoder fed one byte at a time yields every message in order" {
    var dec Decoder = .empty
    defer dec.deinit(std.testing.allocator)
    wire := "Content-Length: 3\r\n\r\noneContent-Length: 3\r\n\r\ntwo"
    var seen usize = 0
    for ch in wire {
        single := [1]u8{ch}
        try dec.feed(std.testing.allocator, single[0..])
        // Drain every frame this byte completed (at most one here).
        while true {
            body := (try dec.next()) orelse break
            var want []u8 = "two"
            if seen == 0 {
                want = "one"
            }
            try std.testing.expect(core.mem.eql(u8, body, want))
            seen += 1
        }
    }
    try std.testing.expectEqual(2, seen)
}

test "a decoder hands out back-to-back frames from one feed" {
    var d Decoder = .empty
    defer d.deinit(std.testing.allocator)
    try d.feed(std.testing.allocator, "Content-Length: 1\r\n\r\naContent-Length: 1\r\n\r\nb")
    try std.testing.expect(core.mem.eql(u8, (try d.next()).?, "a"))
    try std.testing.expect(core.mem.eql(u8, (try d.next()).?, "b"))
    try std.testing.expect((try d.next()) == null)
}

// Covers the compaction path no other test reaches: a partial second frame
// longer than the consumed first frame, so the unread tail overlaps the
// bytes it is copied over when the next `feed` compacts the buffer.
test "a decoder keeps an unread tail that overlaps the consumed prefix" {
    var d Decoder = .empty
    defer d.deinit(std.testing.allocator)
    // A 22-byte frame, then 24 bytes of a 26-byte frame.
    try d.feed(std.testing.allocator, "Content-Length: 1\r\n\r\naContent-Length: 5\r\n\r\nhel")
    try std.testing.expect(core.mem.eql(u8, (try d.next()).?, "a"))
    try std.testing.expect((try d.next()) == null)
    // Compaction moves 24 unread bytes down by only 22: the ranges overlap.
    try d.feed(std.testing.allocator, "lo")
    try std.testing.expect(core.mem.eql(u8, (try d.next()).?, "hello"))
    try std.testing.expect((try d.next()) == null)
}

// Allocation-failure probe: feed one whole frame into a fresh decoder and
// read it back, releasing the decoder on every exit path.
fn feedSweep(alloc std.mem.Allocator) !void {
    var dec Decoder = .empty
    defer dec.deinit(alloc)
    try dec.feed(alloc, "Content-Length: 3\r\n\r\nxyz")
    try std.testing.expect(core.mem.eql(u8, (try dec.next()).?, "xyz"))
}

test "feed survives failure at every allocation point" {
    // Runs feedSweep repeatedly, failing a different allocation each time.
    gpa := std.testing.allocator
    try std.testing.checkAllAllocationFailures(gpa, feedSweep, .{})
}