Commit 61482be153

Vexu <git@vexu.eu>
2019-12-19 19:54:39
translate-c-2 improve macro fn ptr caller
1 parent f837c7c
Changed files (4)
lib
src-self-hosted
test
lib/std/zig/ast.zig
@@ -1436,8 +1436,8 @@ pub const Node = struct {
             AssignMod,
             AssignAdd,
             AssignAddWrap,
-            AssignMult,
-            AssignMultWrap,
+            AssignMul,
+            AssignMulWrap,
             BangEqual,
             BitAnd,
             BitOr,
@@ -1495,8 +1495,8 @@ pub const Node = struct {
                 Op.AssignMod,
                 Op.AssignAdd,
                 Op.AssignAddWrap,
-                Op.AssignMult,
-                Op.AssignMultWrap,
+                Op.AssignMul,
+                Op.AssignMulWrap,
                 Op.BangEqual,
                 Op.BitAnd,
                 Op.BitOr,
lib/std/zig/parse.zig
@@ -1981,7 +1981,7 @@ fn parseAssignOp(arena: *Allocator, it: *TokenIterator, tree: *Tree) !?*Node {
 
     const token = nextToken(it);
     const op = switch (token.ptr.id) {
-        .AsteriskEqual => Op{ .AssignMult = {} },
+        .AsteriskEqual => Op{ .AssignMul = {} },
         .SlashEqual => Op{ .AssignDiv = {} },
         .PercentEqual => Op{ .AssignMod = {} },
         .PlusEqual => Op{ .AssignAdd = {} },
@@ -1991,7 +1991,7 @@ fn parseAssignOp(arena: *Allocator, it: *TokenIterator, tree: *Tree) !?*Node {
         .AmpersandEqual => Op{ .AssignBitAnd = {} },
         .CaretEqual => Op{ .AssignBitXor = {} },
         .PipeEqual => Op{ .AssignBitOr = {} },
-        .AsteriskPercentEqual => Op{ .AssignMultWrap = {} },
+        .AsteriskPercentEqual => Op{ .AssignMulWrap = {} },
         .PlusPercentEqual => Op{ .AssignAddWrap = {} },
         .MinusPercentEqual => Op{ .AssignSubWrap = {} },
         .Equal => Op{ .Assign = {} },
src-self-hosted/translate_c.zig
@@ -8,6 +8,7 @@ const Token = std.zig.Token;
 usingnamespace @import("clang.zig");
 const ctok = @import("c_tokenizer.zig");
 const CToken = ctok.CToken;
+const mem = std.mem;
 
 const CallingConvention = std.builtin.TypeInfo.CallingConvention;
 
@@ -83,7 +84,7 @@ const Scope = struct {
         fn getAlias(scope: *Block, name: []const u8) ?[]const u8 {
             var it = scope.variables.iterator(0);
             while (it.next()) |p| {
-                if (std.mem.eql(u8, p.name, name))
+                if (mem.eql(u8, p.name, name))
                     return p.alias;
             }
             return scope.base.parent.?.getAlias(name);
@@ -92,7 +93,7 @@ const Scope = struct {
         fn contains(scope: *Block, name: []const u8) bool {
             var it = scope.variables.iterator(0);
             while (it.next()) |p| {
-                if (std.mem.eql(u8, p.name, name))
+                if (mem.eql(u8, p.name, name))
                     return true;
             }
             return scope.base.parent.?.contains(name);
@@ -137,7 +138,7 @@ const Scope = struct {
         fn getAlias(scope: *FnDef, name: []const u8) ?[]const u8 {
             var it = scope.params.iterator(0);
             while (it.next()) |p| {
-                if (std.mem.eql(u8, p.name, name))
+                if (mem.eql(u8, p.name, name))
                     return p.alias;
             }
             return scope.base.parent.?.getAlias(name);
@@ -146,7 +147,7 @@ const Scope = struct {
         fn contains(scope: *FnDef, name: []const u8) bool {
             var it = scope.params.iterator(0);
             while (it.next()) |p| {
-                if (std.mem.eql(u8, p.name, name))
+                if (mem.eql(u8, p.name, name))
                     return true;
             }
             return scope.base.parent.?.contains(name);
@@ -233,13 +234,13 @@ const Context = struct {
         return c.mangle_count;
     }
 
-    fn a(c: *Context) *std.mem.Allocator {
+    fn a(c: *Context) *mem.Allocator {
         return &c.tree.arena_allocator.allocator;
     }
 
     /// Convert a null-terminated C string to a slice allocated in the arena
     fn str(c: *Context, s: [*:0]const u8) ![]u8 {
-        return std.mem.dupe(c.a(), u8, std.mem.toSliceConst(u8, s));
+        return mem.dupe(c.a(), u8, mem.toSliceConst(u8, s));
     }
 
     /// Convert a clang source location to a file:line:column string
@@ -255,7 +256,7 @@ const Context = struct {
 };
 
 pub fn translate(
-    backing_allocator: *std.mem.Allocator,
+    backing_allocator: *mem.Allocator,
     args_begin: [*]?[*]const u8,
     args_end: [*]?[*]const u8,
     errors: *[]ClangErrMsg,
@@ -540,29 +541,29 @@ fn transTypeDef(c: *Context, typedef_decl: *const ZigClangTypedefNameDecl) Error
 
     const typedef_name = try c.str(ZigClangDecl_getName_bytes_begin(@ptrCast(*const ZigClangDecl, typedef_decl)));
 
-    if (std.mem.eql(u8, typedef_name, "uint8_t"))
+    if (mem.eql(u8, typedef_name, "uint8_t"))
         return transTypeDefAsBuiltin(c, typedef_decl, "u8")
-    else if (std.mem.eql(u8, typedef_name, "int8_t"))
+    else if (mem.eql(u8, typedef_name, "int8_t"))
         return transTypeDefAsBuiltin(c, typedef_decl, "i8")
-    else if (std.mem.eql(u8, typedef_name, "uint16_t"))
+    else if (mem.eql(u8, typedef_name, "uint16_t"))
         return transTypeDefAsBuiltin(c, typedef_decl, "u16")
-    else if (std.mem.eql(u8, typedef_name, "int16_t"))
+    else if (mem.eql(u8, typedef_name, "int16_t"))
         return transTypeDefAsBuiltin(c, typedef_decl, "i16")
-    else if (std.mem.eql(u8, typedef_name, "uint32_t"))
+    else if (mem.eql(u8, typedef_name, "uint32_t"))
         return transTypeDefAsBuiltin(c, typedef_decl, "u32")
-    else if (std.mem.eql(u8, typedef_name, "int32_t"))
+    else if (mem.eql(u8, typedef_name, "int32_t"))
         return transTypeDefAsBuiltin(c, typedef_decl, "i32")
-    else if (std.mem.eql(u8, typedef_name, "uint64_t"))
+    else if (mem.eql(u8, typedef_name, "uint64_t"))
         return transTypeDefAsBuiltin(c, typedef_decl, "u64")
-    else if (std.mem.eql(u8, typedef_name, "int64_t"))
+    else if (mem.eql(u8, typedef_name, "int64_t"))
         return transTypeDefAsBuiltin(c, typedef_decl, "i64")
-    else if (std.mem.eql(u8, typedef_name, "intptr_t"))
+    else if (mem.eql(u8, typedef_name, "intptr_t"))
         return transTypeDefAsBuiltin(c, typedef_decl, "isize")
-    else if (std.mem.eql(u8, typedef_name, "uintptr_t"))
+    else if (mem.eql(u8, typedef_name, "uintptr_t"))
         return transTypeDefAsBuiltin(c, typedef_decl, "usize")
-    else if (std.mem.eql(u8, typedef_name, "ssize_t"))
+    else if (mem.eql(u8, typedef_name, "ssize_t"))
         return transTypeDefAsBuiltin(c, typedef_decl, "isize")
-    else if (std.mem.eql(u8, typedef_name, "size_t"))
+    else if (mem.eql(u8, typedef_name, "size_t"))
         return transTypeDefAsBuiltin(c, typedef_decl, "usize");
 
     _ = try c.decl_table.put(@ptrToInt(ZigClangTypedefNameDecl_getCanonicalDecl(typedef_decl)), typedef_name);
@@ -704,7 +705,7 @@ fn transEnumDecl(c: *Context, enum_decl: *const ZigClangEnumDecl) Error!?*ast.No
 
     const name = try std.fmt.allocPrint(c.a(), "enum_{}", .{bare_name});
     _ = try c.decl_table.put(@ptrToInt(ZigClangEnumDecl_getCanonicalDecl(enum_decl)), name);
-    const node = try transCreateNodeVarDecl(c, true, true, name);
+    const node = try transCreateNodeVarDecl(c, !is_unnamed, true, name);
     node.eq_token = try appendToken(c, .Equal, "=");
 
     node.init_node = if (ZigClangEnumDecl_getDefinition(enum_decl)) |enum_def| blk: {
@@ -762,7 +763,7 @@ fn transEnumDecl(c: *Context, enum_decl: *const ZigClangEnumDecl) Error!?*ast.No
 
             const enum_val_name = try c.str(ZigClangDecl_getName_bytes_begin(@ptrCast(*const ZigClangDecl, enum_const)));
 
-            const field_name = if (!is_unnamed and std.mem.startsWith(u8, enum_val_name, bare_name))
+            const field_name = if (!is_unnamed and mem.startsWith(u8, enum_val_name, bare_name))
                 enum_val_name[bare_name.len..]
             else
                 enum_val_name;
@@ -1448,7 +1449,7 @@ fn writeEscapedString(buf: []u8, s: []const u8) void {
     var i: usize = 0;
     for (s) |c| {
         const escaped = escapeChar(c, &char_buf);
-        std.mem.copy(u8, buf[i..], escaped);
+        mem.copy(u8, buf[i..], escaped);
         i += escaped.len;
     }
 }
@@ -1537,16 +1538,6 @@ fn transCCast(
         builtin_node.rparen_token = try appendToken(rp.c, .RParen, ")");
         return &builtin_node.base;
     }
-    // TODO
-    // if (ZigClangQualType_getTypeClass(dst_type) == .Enum and
-    //     ZigClangQualType_getTypeClass(src_type) != .Enum) {
-    //     const builtin_node = try transCreateNodeBuiltinFnCall(rp.c, "@intToEnum");
-    //     try builtin_node.params.push(try transQualType(rp, dst_type, loc));
-    //     _ = try appendToken(rp.c, .Comma, ",");
-    //     try builtin_node.params.push(expr);
-    //     builtin_node.rparen_token = try appendToken(rp.c, .RParen, ")");
-    //     return &builtin_node.base;
-    // }
     // TODO: maybe widen to increase size
     // TODO: maybe bitcast to change sign
     // TODO: maybe truncate to reduce size
@@ -2364,9 +2355,9 @@ fn transCreatePostCrement(
 fn transCompoundAssignOperator(rp: RestorePoint, scope: *Scope, stmt: *const ZigClangCompoundAssignOperator, used: ResultUsed) TransError!*ast.Node {
     switch (ZigClangCompoundAssignOperator_getOpcode(stmt)) {
         .MulAssign => if (qualTypeHaswrappingOverflow(ZigClangCompoundAssignOperator_getType(stmt)))
-            return transCreateCompoundAssign(rp, scope, stmt, .AssignMultWrap, .AsteriskPercentEqual, "*%=", .MultWrap, .AsteriskPercent, "*%", used)
+            return transCreateCompoundAssign(rp, scope, stmt, .AssignMulWrap, .AsteriskPercentEqual, "*%=", .MultWrap, .AsteriskPercent, "*%", used)
         else
-            return transCreateCompoundAssign(rp, scope, stmt, .AssignMult, .AsteriskEqual, "*=", .Mult, .Asterisk, "*", used),
+            return transCreateCompoundAssign(rp, scope, stmt, .AssignMul, .AsteriskEqual, "*=", .Mult, .Asterisk, "*", used),
         .AddAssign => if (qualTypeHaswrappingOverflow(ZigClangCompoundAssignOperator_getType(stmt)))
             return transCreateCompoundAssign(rp, scope, stmt, .AssignAddWrap, .PlusPercentEqual, "+%=", .AddWrap, .PlusPercent, "+%", used)
         else
@@ -2645,13 +2636,13 @@ fn qualTypeIntBitWidth(rp: RestorePoint, qt: ZigClangQualType, source_loc: ZigCl
             const typedef_decl = ZigClangTypedefType_getDecl(typedef_ty);
             const type_name = try rp.c.str(ZigClangDecl_getName_bytes_begin(@ptrCast(*const ZigClangDecl, typedef_decl)));
 
-            if (std.mem.eql(u8, type_name, "uint8_t") or std.mem.eql(u8, type_name, "int8_t")) {
+            if (mem.eql(u8, type_name, "uint8_t") or mem.eql(u8, type_name, "int8_t")) {
                 return 8;
-            } else if (std.mem.eql(u8, type_name, "uint16_t") or std.mem.eql(u8, type_name, "int16_t")) {
+            } else if (mem.eql(u8, type_name, "uint16_t") or mem.eql(u8, type_name, "int16_t")) {
                 return 16;
-            } else if (std.mem.eql(u8, type_name, "uint32_t") or std.mem.eql(u8, type_name, "int32_t")) {
+            } else if (mem.eql(u8, type_name, "uint32_t") or mem.eql(u8, type_name, "int32_t")) {
                 return 32;
-            } else if (std.mem.eql(u8, type_name, "uint64_t") or std.mem.eql(u8, type_name, "int64_t")) {
+            } else if (mem.eql(u8, type_name, "uint64_t") or mem.eql(u8, type_name, "int64_t")) {
                 return 64;
             } else {
                 return 0;
@@ -3176,7 +3167,7 @@ fn transCreateNodeOpaqueType(c: *Context) !*ast.Node {
     return &call_node.base;
 }
 
-fn transCreateNodeMacroFn(c: *Context, name: []const u8, ref: *ast.Node, proto_alias_node: *ast.Node) !*ast.Node {
+fn transCreateNodeMacroFn(c: *Context, name: []const u8, ref: *ast.Node, proto_alias: *ast.Node.FnProto) !*ast.Node {
     const scope = &c.global_scope.base;
 
     const pub_tok = try appendToken(c, .Keyword_pub, "pub");
@@ -3185,8 +3176,6 @@ fn transCreateNodeMacroFn(c: *Context, name: []const u8, ref: *ast.Node, proto_a
     const name_tok = try appendIdentifier(c, name);
     _ = try appendToken(c, .LParen, "(");
 
-    const proto_alias = proto_alias_node.cast(ast.Node.FnProto).?;
-
     var fn_params = ast.Node.FnProto.ParamList.init(c.a());
     var it = proto_alias.params.iterator(0);
     while (it.next()) |pn| {
@@ -3948,27 +3937,27 @@ fn isZigPrimitiveType(name: []const u8) bool {
         return true;
     }
     // void is invalid in c so it doesn't need to be checked.
-    return std.mem.eql(u8, name, "comptime_float") or
-        std.mem.eql(u8, name, "comptime_int") or
-        std.mem.eql(u8, name, "bool") or
-        std.mem.eql(u8, name, "isize") or
-        std.mem.eql(u8, name, "usize") or
-        std.mem.eql(u8, name, "f16") or
-        std.mem.eql(u8, name, "f32") or
-        std.mem.eql(u8, name, "f64") or
-        std.mem.eql(u8, name, "f128") or
-        std.mem.eql(u8, name, "c_longdouble") or
-        std.mem.eql(u8, name, "noreturn") or
-        std.mem.eql(u8, name, "type") or
-        std.mem.eql(u8, name, "anyerror") or
-        std.mem.eql(u8, name, "c_short") or
-        std.mem.eql(u8, name, "c_ushort") or
-        std.mem.eql(u8, name, "c_int") or
-        std.mem.eql(u8, name, "c_uint") or
-        std.mem.eql(u8, name, "c_long") or
-        std.mem.eql(u8, name, "c_ulong") or
-        std.mem.eql(u8, name, "c_longlong") or
-        std.mem.eql(u8, name, "c_ulonglong");
+    return mem.eql(u8, name, "comptime_float") or
+        mem.eql(u8, name, "comptime_int") or
+        mem.eql(u8, name, "bool") or
+        mem.eql(u8, name, "isize") or
+        mem.eql(u8, name, "usize") or
+        mem.eql(u8, name, "f16") or
+        mem.eql(u8, name, "f32") or
+        mem.eql(u8, name, "f64") or
+        mem.eql(u8, name, "f128") or
+        mem.eql(u8, name, "c_longdouble") or
+        mem.eql(u8, name, "noreturn") or
+        mem.eql(u8, name, "type") or
+        mem.eql(u8, name, "anyerror") or
+        mem.eql(u8, name, "c_short") or
+        mem.eql(u8, name, "c_ushort") or
+        mem.eql(u8, name, "c_int") or
+        mem.eql(u8, name, "c_uint") or
+        mem.eql(u8, name, "c_long") or
+        mem.eql(u8, name, "c_ulong") or
+        mem.eql(u8, name, "c_longlong") or
+        mem.eql(u8, name, "c_ulonglong");
 }
 
 fn isValidZigIdentifier(name: []const u8) bool {
@@ -4039,13 +4028,13 @@ fn transPreprocessorEntities(c: *Context, unit: *ZigClangASTUnit) Error!void {
 
                 var tok_it = tok_list.iterator(0);
                 const first_tok = tok_it.next().?;
-                assert(first_tok.id == .Identifier and std.mem.eql(u8, first_tok.bytes, name));
+                assert(first_tok.id == .Identifier and mem.eql(u8, first_tok.bytes, name));
                 const next = tok_it.peek().?;
                 switch (next.id) {
                     .Identifier => {
                         // if it equals itself, ignore. for example, from stdio.h:
                         // #define stdin stdin
-                        if (std.mem.eql(u8, checked_name, next.bytes)) {
+                        if (mem.eql(u8, checked_name, next.bytes)) {
                             continue;
                         }
                     },
@@ -4493,21 +4482,71 @@ fn tokenSlice(c: *Context, token: ast.TokenIndex) []const u8 {
     return c.source_buffer.toSliceConst()[tok.start..tok.end];
 }
 
-fn getFnDecl(c: *Context, ref: *ast.Node) ?*ast.Node {
-    const init = if (ref.cast(ast.Node.VarDecl)) |v| v.init_node.? else return null;
-    const name = if (init.cast(ast.Node.Identifier)) |id|
-        tokenSlice(c, id.token)
-    else
-        return null;
-    // TODO a.b.c
-    if (c.global_scope.sym_table.get(name)) |kv| {
-        if (kv.value.cast(ast.Node.VarDecl)) |val| {
-            if (val.type_node) |type_node| {
-                if (type_node.cast(ast.Node.PrefixOp)) |casted| {
-                    if (casted.rhs.id == .FnProto) {
-                        return casted.rhs;
+fn getContainer(c: *Context, node: *ast.Node) ?*ast.Node {
+    if (node.id == .ContainerDecl) {
+        return node;
+    } else if (node.id == .PrefixOp) {
+        return node;
+    } else if (node.cast(ast.Node.Identifier)) |ident| {
+        if (c.global_scope.sym_table.get(tokenSlice(c, ident.token))) |kv| {
+            if (kv.value.cast(ast.Node.VarDecl)) |var_decl|
+                return getContainer(c, var_decl.init_node.?);
+        }
+    } else if (node.cast(ast.Node.InfixOp)) |infix| {
+        if (infix.op != .Period)
+            return null;
+        if (getContainerTypeOf(c, infix.lhs)) |ty_node| {
+            if (ty_node.cast(ast.Node.ContainerDecl)) |container| {
+                var it = container.fields_and_decls.iterator(0);
+                while (it.next()) |field_ref| {
+                    const field = field_ref.*.cast(ast.Node.ContainerField).?;
+                    const ident = infix.rhs.cast(ast.Node.Identifier).?;
+                    if (mem.eql(u8, tokenSlice(c, field.name_token), tokenSlice(c, ident.token))) {
+                        return getContainer(c, field.type_expr.?);
+                    }
+                }
+            }
+        }
+    }
+    return null;
+}
+
+fn getContainerTypeOf(c: *Context, ref: *ast.Node) ?*ast.Node {
+    if (ref.cast(ast.Node.Identifier)) |ident| {
+        if (c.global_scope.sym_table.get(tokenSlice(c, ident.token))) |kv| {
+            if (kv.value.cast(ast.Node.VarDecl)) |var_decl| {
+                if (var_decl.type_node) |ty|
+                    return getContainer(c, ty);
+            }
+        }
+    } else if (ref.cast(ast.Node.InfixOp)) |infix| {
+        if (infix.op != .Period)
+            return null;
+        if (getContainerTypeOf(c, infix.lhs)) |ty_node| {
+            if (ty_node.cast(ast.Node.ContainerDecl)) |container| {
+                var it = container.fields_and_decls.iterator(0);
+                while (it.next()) |field_ref| {
+                    const field = field_ref.*.cast(ast.Node.ContainerField).?;
+                    const ident = infix.rhs.cast(ast.Node.Identifier).?;
+                    if (mem.eql(u8, tokenSlice(c, field.name_token), tokenSlice(c, ident.token))) {
+                        return getContainer(c, field.type_expr.?);
                     }
                 }
+            } else
+                return ty_node;
+        }
+    }
+    return null;
+}
+
+fn getFnProto(c: *Context, ref: *ast.Node) ?*ast.Node.FnProto {
+    const init = if (ref.cast(ast.Node.VarDecl)) |v| v.init_node.? else return null;
+    if (getContainerTypeOf(c, init)) |ty_node| {
+        if (ty_node.cast(ast.Node.PrefixOp)) |prefix| {
+            if (prefix.op == .OptionalType) {
+                if (prefix.rhs.cast(ast.Node.FnProto)) |fn_proto| {
+                    return fn_proto;
+                }
             }
         }
     }
@@ -4517,7 +4556,7 @@ fn getFnDecl(c: *Context, ref: *ast.Node) ?*ast.Node {
 fn addMacros(c: *Context) !void {
     var macro_it = c.global_scope.macro_table.iterator();
     while (macro_it.next()) |kv| {
-        if (getFnDecl(c, kv.value)) |proto_node| {
+        if (getFnProto(c, kv.value)) |proto_node| {
             // If a macro aliases a global variable which is a function pointer, we conclude that
             // the macro is intended to represent a function that assumes the function pointer
             // variable is non-null and calls it.
test/translate_c.zig
@@ -887,7 +887,7 @@ pub fn addCases(cases: *tests.TranslateCContext) void {
         \\pub const a = enum_unnamed_1.a;
         \\pub const b = enum_unnamed_1.b;
         \\pub const c = enum_unnamed_1.c;
-        \\pub const enum_unnamed_1 = extern enum {
+        \\const enum_unnamed_1 = extern enum {
         \\    a,
         \\    b,
         \\    c,
@@ -896,7 +896,7 @@ pub fn addCases(cases: *tests.TranslateCContext) void {
         \\pub const e = enum_unnamed_2.e;
         \\pub const f = enum_unnamed_2.f;
         \\pub const g = enum_unnamed_2.g;
-        \\pub const enum_unnamed_2 = extern enum {
+        \\const enum_unnamed_2 = extern enum {
         \\    e = 0,
         \\    f = 4,
         \\    g = 5,
@@ -905,7 +905,7 @@ pub fn addCases(cases: *tests.TranslateCContext) void {
         \\pub const i = enum_unnamed_3.i;
         \\pub const j = enum_unnamed_3.j;
         \\pub const k = enum_unnamed_3.k;
-        \\pub const enum_unnamed_3 = extern enum {
+        \\const enum_unnamed_3 = extern enum {
         \\    i,
         \\    j,
         \\    k,
@@ -1027,10 +1027,10 @@ pub fn addCases(cases: *tests.TranslateCContext) void {
         \\pub extern var glProcs: union_OpenGLProcs;
     ,
         \\pub const glClearPFN = PFNGLCLEARPROC;
- // , // TODO
-    //     \\pub inline fn glClearUnion(arg_1: GLbitfield) void {
-    //     \\    return glProcs.gl.Clear.?(arg_1);
-    //     \\}
+    ,
+        \\pub inline fn glClearUnion(arg_2: GLbitfield) void {
+        \\    return glProcs.gl.Clear.?(arg_2);
+        \\}
         ,
         \\pub const OpenGLProcs = union_OpenGLProcs;
     });
@@ -1348,7 +1348,7 @@ pub fn addCases(cases: *tests.TranslateCContext) void {
     , &[_][]const u8{
         \\pub const One = enum_unnamed_1.One;
         \\pub const Two = enum_unnamed_1.Two;
-        \\pub const enum_unnamed_1 = extern enum {
+        \\const enum_unnamed_1 = extern enum {
         \\    One,
         \\    Two,
         \\};