master
//! A buffered DER encoder.
//!
//! Prefers calling container's `fn encodeDer(self: @This(), encoder: *der.Encoder)`.
//! That function should encode values, lengths, then tags, in that order: the output
//! buffer is filled back to front, so whatever is written last ends up first.
/// Encoded output. Writes are placed in front of previously written bytes.
buffer: ArrayListReverse,
/// The field tag set by a parent container.
/// This is needed because we might visit an implicitly tagged container with a `fn encodeDer`.
/// When it is implicit, its number and class replace those of the next tag written via `tag`.
field_tag: ?FieldTag = null,
  9
 10pub fn init(allocator: std.mem.Allocator) Encoder {
 11    return Encoder{ .buffer = ArrayListReverse.init(allocator) };
 12}
 13
 14pub fn deinit(self: *Encoder) void {
 15    self.buffer.deinit();
 16}
 17
 18/// Encode any value.
 19pub fn any(self: *Encoder, val: anytype) !void {
 20    const T = @TypeOf(val);
 21    try self.anyTag(Tag.fromZig(T), val);
 22}
 23
 24fn anyTag(self: *Encoder, tag_: Tag, val: anytype) !void {
 25    const T = @TypeOf(val);
 26    if (std.meta.hasFn(T, "encodeDer")) return try val.encodeDer(self);
 27    const start = self.buffer.data.len;
 28    const merged_tag = self.mergedTag(tag_);
 29
 30    switch (@typeInfo(T)) {
 31        .@"struct" => |info| {
 32            inline for (0..info.fields.len) |i| {
 33                const f = info.fields[info.fields.len - i - 1];
 34                const field_val = @field(val, f.name);
 35                const field_tag = FieldTag.fromContainer(T, f.name);
 36
 37                // > The encoding of a set value or sequence value shall not include an encoding for any
 38                // > component value which is equal to its default value.
 39                const is_default = if (f.is_comptime) false else if (f.default_value_ptr) |v| brk: {
 40                    const default_val: *const f.type = @ptrCast(@alignCast(v));
 41                    break :brk std.mem.eql(u8, std.mem.asBytes(default_val), std.mem.asBytes(&field_val));
 42                } else false;
 43
 44                if (!is_default) {
 45                    const start2 = self.buffer.data.len;
 46                    self.field_tag = field_tag;
 47                    // will merge with self.field_tag.
 48                    // may mutate self.field_tag.
 49                    try self.anyTag(Tag.fromZig(f.type), field_val);
 50                    if (field_tag) |ft| {
 51                        if (ft.explicit) {
 52                            try self.length(self.buffer.data.len - start2);
 53                            try self.tag(ft.toTag());
 54                            self.field_tag = null;
 55                        }
 56                    }
 57                }
 58            }
 59        },
 60        .bool => try self.buffer.prependSlice(&[_]u8{if (val) 0xff else 0}),
 61        .int => try self.int(T, val),
 62        .@"enum" => |e| {
 63            if (@hasDecl(T, "oids")) {
 64                return self.any(T.oids.enumToOid(val));
 65            } else {
 66                try self.int(e.tag_type, @intFromEnum(val));
 67            }
 68        },
 69        .optional => if (val) |v| return try self.anyTag(tag_, v),
 70        .null => {},
 71        else => @compileError("cannot encode type " ++ @typeName(T)),
 72    }
 73
 74    try self.length(self.buffer.data.len - start);
 75    try self.tag(merged_tag);
 76}
 77
 78/// Encode a tag.
 79pub fn tag(self: *Encoder, tag_: Tag) !void {
 80    const t = self.mergedTag(tag_);
 81    try t.encode(self.writer());
 82}
 83
 84fn mergedTag(self: *Encoder, tag_: Tag) Tag {
 85    var res = tag_;
 86    if (self.field_tag) |ft| {
 87        if (!ft.explicit) {
 88            res.number = @enumFromInt(ft.number);
 89            res.class = ft.class;
 90        }
 91    }
 92    return res;
 93}
 94
 95/// Encode a length.
 96pub fn length(self: *Encoder, len: usize) !void {
 97    const writer_ = self.writer();
 98    if (len < 128) {
 99        try writer_.writeInt(u8, @intCast(len), .big);
100        return;
101    }
102    inline for ([_]type{ u8, u16, u32 }) |T| {
103        if (len < std.math.maxInt(T)) {
104            try writer_.writeInt(T, @intCast(len), .big);
105            try writer_.writeInt(u8, @sizeOf(T) | 0x80, .big);
106            return;
107        }
108    }
109    return error.InvalidLength;
110}
111
/// Encode a complete element: `tag_` (merged with any pending implicit field tag),
/// the length of `bytes`, and `bytes` as the contents.
pub fn tagBytes(self: *Encoder, tag_: Tag, bytes: []const u8) !void {
    // The buffer is filled back to front: contents first, then length, then tag.
    try self.buffer.prependSlice(bytes);
    const content_len = bytes.len;
    try self.length(content_len);
    return self.tag(tag_);
}
118
/// Warning: This writer writes backwards. `fn print` will NOT work as expected.
/// Bytes from later writes end up in front of bytes from earlier writes.
pub fn writer(self: *Encoder) ArrayListReverse.Writer {
    const buf = &self.buffer;
    return buf.writer();
}
123
/// Write the content octets of a DER INTEGER holding `value`.
///
/// The value is emitted as the shortest big-endian two's-complement byte string that
/// represents it. An unsigned value whose top content bit would otherwise be set gets a
/// leading 0x00 octet, so it is not read back as negative. Zero encodes as a single 0x00.
/// Both signed and unsigned integer types of any width are accepted.
/// Only the contents are written; the caller is responsible for the length and tag.
fn int(self: *Encoder, comptime T: type, value: T) !void {
    // One extra bit keeps widened unsigned values non-negative. At least 16 bits keeps
    // the shift by 8 below within range for very small types.
    const Wide = std.meta.Int(.signed, @max(@bitSizeOf(T) + 1, 16));
    var rest: Wide = value;

    const writer_ = self.writer();
    while (true) {
        const octet: u8 = @intCast(rest & 0xff);
        // Writes are prepended, so emitting the least significant octet first
        // produces big-endian output.
        try writer_.writeByte(octet);
        rest >>= 8; // arithmetic shift: negative values converge to -1
        const sign_set = (octet & 0x80) != 0;
        // Stop once the remaining high bits are just sign extension of the last octet.
        if ((rest == 0 and !sign_set) or (rest == -1 and sign_set)) break;
    }
}

test int {
    const allocator = std.testing.allocator;
    var encoder = Encoder.init(allocator);
    defer encoder.deinit();

    try encoder.int(u8, 0);
    try std.testing.expectEqualSlices(u8, &[_]u8{0}, encoder.buffer.data);

    encoder.buffer.clearAndFree();
    try encoder.int(u8, 0x7f);
    try std.testing.expectEqualSlices(u8, &[_]u8{0x7f}, encoder.buffer.data);

    // Without the leading zero octet these would read back as negative numbers.
    encoder.buffer.clearAndFree();
    try encoder.int(u8, 0x80);
    try std.testing.expectEqualSlices(u8, &[_]u8{ 0, 0x80 }, encoder.buffer.data);

    encoder.buffer.clearAndFree();
    try encoder.int(u16, 0x00ff);
    try std.testing.expectEqualSlices(u8, &[_]u8{ 0, 0xff }, encoder.buffer.data);

    encoder.buffer.clearAndFree();
    try encoder.int(u16, 0x8000);
    try std.testing.expectEqualSlices(u8, &[_]u8{ 0, 0x80, 0 }, encoder.buffer.data);

    encoder.buffer.clearAndFree();
    try encoder.int(u16, 0x0101);
    try std.testing.expectEqualSlices(u8, &[_]u8{ 1, 1 }, encoder.buffer.data);

    encoder.buffer.clearAndFree();
    try encoder.int(u32, 0xffff);
    try std.testing.expectEqualSlices(u8, &[_]u8{ 0, 0xff, 0xff }, encoder.buffer.data);

    // Signed values use the minimal two's-complement form.
    encoder.buffer.clearAndFree();
    try encoder.int(i16, -1);
    try std.testing.expectEqualSlices(u8, &[_]u8{0xff}, encoder.buffer.data);

    encoder.buffer.clearAndFree();
    try encoder.int(i16, -129);
    try std.testing.expectEqualSlices(u8, &[_]u8{ 0xff, 0x7f }, encoder.buffer.data);

    encoder.buffer.clearAndFree();
    try encoder.int(i8, -128);
    try std.testing.expectEqualSlices(u8, &[_]u8{0x80}, encoder.buffer.data);
}
159
160const std = @import("std");
161const Oid = @import("../Oid.zig");
162const asn1 = @import("../../asn1.zig");
163const ArrayListReverse = @import("./ArrayListReverse.zig");
164const Tag = asn1.Tag;
165const FieldTag = asn1.FieldTag;
166const Encoder = @This();