Commit 7cf2cbb33e

Andrew Kelley <andrew@ziglang.org>
2023-05-18 05:39:12
std.crypto.tls.Client.readvAdvanced: fix bugs
* When there is buffered cleartext, return it without calling the underlying read function. This prevents buffer overflow due to space used up by cleartext. * Avoid clearing the buffer when the buffered cleartext could not be completely given to the result read buffer, and there is some buffered ciphertext left. * Instead of rounding up the amount of bytes to ask for to the nearest TLS record size, round down, with a minimum of 1. This prevents the code path from being taken which requires extra memory copies. * Avoid calling `@memcpy` with overlapping arguments. closes #15590
1 parent 378264d
Changed files (1)
lib
std
crypto
lib/std/crypto/tls/Client.zig
@@ -924,7 +924,9 @@ pub fn readvAdvanced(c: *Client, stream: anytype, iovecs: []const std.os.iovec)
         const amt = @intCast(u15, vp.put(partial_cleartext));
         c.partial_cleartext_idx += amt;
 
-        if (c.partial_ciphertext_end == c.partial_ciphertext_idx) {
+        if (c.partial_cleartext_idx == c.partial_ciphertext_idx and
+            c.partial_ciphertext_end == c.partial_ciphertext_idx)
+        {
             // The buffer is now empty.
             c.partial_cleartext_idx = 0;
             c.partial_ciphertext_idx = 0;
@@ -935,7 +937,7 @@ pub fn readvAdvanced(c: *Client, stream: anytype, iovecs: []const std.os.iovec)
             c.partial_ciphertext_end = 0;
             assert(vp.total == amt);
             return amt;
-        } else if (amt <= partial_cleartext.len) {
+        } else if (amt > 0) {
             // We don't need more data, so don't call read.
             assert(vp.total == amt);
             return amt;
@@ -970,8 +972,8 @@ pub fn readvAdvanced(c: *Client, stream: anytype, iovecs: []const std.os.iovec)
         },
     };
 
-    // Cleartext capacity of output buffer, in records, rounded up.
-    const buf_cap = (cleartext_buf_len +| (max_ciphertext_len - 1)) / max_ciphertext_len;
+    // Cleartext capacity of output buffer, in records. Minimum one full record.
+    const buf_cap = @max(cleartext_buf_len / max_ciphertext_len, 1);
     const wanted_read_len = buf_cap * (max_ciphertext_len + tls.record_header_len);
     const ask_len = @max(wanted_read_len, cleartext_stack_buffer.len);
     const ask_iovecs = limitVecs(&ask_iovecs_buf, ask_len);
@@ -1029,7 +1031,7 @@ pub fn readvAdvanced(c: *Client, stream: anytype, iovecs: []const std.os.iovec)
             if (frag1.len < second_len)
                 return finishRead2(c, first, frag1, vp.total);
 
-            @memcpy(frag[0..in], first);
+            limitedOverlapCopy(frag, in);
             @memcpy(frag[first.len..][0..second_len], frag1[0..second_len]);
             frag = frag[0..full_record_len];
             frag1 = frag1[second_len..];
@@ -1059,7 +1061,7 @@ pub fn readvAdvanced(c: *Client, stream: anytype, iovecs: []const std.os.iovec)
             if (frag1.len < second_len)
                 return finishRead2(c, first, frag1, vp.total);
 
-            @memcpy(frag[0..in], first);
+            limitedOverlapCopy(frag, in);
             @memcpy(frag[first.len..][0..second_len], frag1[0..second_len]);
             frag = frag[0..full_record_len];
             frag1 = frag1[second_len..];
@@ -1176,8 +1178,10 @@ pub fn readvAdvanced(c: *Client, stream: anytype, iovecs: []const std.os.iovec)
                             if (c.partial_ciphertext_idx > c.partial_cleartext_idx) {
                                 // We have already run out of room in iovecs. Continue
                                 // appending to `partially_read_buffer`.
-                                const dest = c.partially_read_buffer[c.partial_ciphertext_idx..];
-                                @memcpy(dest[0..msg.len], msg);
+                                @memcpy(
+                                    c.partially_read_buffer[c.partial_ciphertext_idx..][0..msg.len],
+                                    msg,
+                                );
                                 c.partial_ciphertext_idx = @intCast(@TypeOf(c.partial_ciphertext_idx), c.partial_ciphertext_idx + msg.len);
                             } else {
                                 const amt = vp.put(msg);
@@ -1223,22 +1227,38 @@ fn finishRead(c: *Client, frag: []const u8, in: usize, out: usize) usize {
     return out;
 }
 
+/// Note that `first` usually overlaps with `c.partially_read_buffer`.
 fn finishRead2(c: *Client, first: []const u8, frag1: []const u8, out: usize) usize {
     if (c.partial_ciphertext_idx > c.partial_cleartext_idx) {
         // There is cleartext at the beginning already which we need to preserve.
         c.partial_ciphertext_end = @intCast(@TypeOf(c.partial_ciphertext_end), c.partial_ciphertext_idx + first.len + frag1.len);
-        @memcpy(c.partially_read_buffer[c.partial_ciphertext_idx..][0..first.len], first);
+        // TODO: eliminate this call to copyForwards
+        std.mem.copyForwards(u8, c.partially_read_buffer[c.partial_ciphertext_idx..][0..first.len], first);
         @memcpy(c.partially_read_buffer[c.partial_ciphertext_idx + first.len ..][0..frag1.len], frag1);
     } else {
         c.partial_cleartext_idx = 0;
         c.partial_ciphertext_idx = 0;
         c.partial_ciphertext_end = @intCast(@TypeOf(c.partial_ciphertext_end), first.len + frag1.len);
+        // TODO: eliminate this call to copyForwards
         std.mem.copyForwards(u8, c.partially_read_buffer[0..first.len], first);
         @memcpy(c.partially_read_buffer[first.len..][0..frag1.len], frag1);
     }
     return out;
 }
 
+fn limitedOverlapCopy(frag: []u8, in: usize) void {
+    const first = frag[in..];
+    if (first.len <= in) {
+        // A single, non-overlapping memcpy suffices.
+        @memcpy(frag[0..first.len], first);
+    } else {
+        // Need two memcpy calls because one alone would overlap.
+        @memcpy(frag[0..in], first[0..in]);
+        const leftover = first.len - in;
+        @memcpy(frag[in..][0..leftover], first[in..][0..leftover]);
+    }
+}
+
 fn straddleByte(s1: []const u8, s2: []const u8, index: usize) u8 {
     if (index < s1.len) {
         return s1[index];