Commit 8ae13f5

Anton Golub <antongolub@antongolub.com>
2024-11-08 14:15:37
feat(cli): allow overriding default script extension (#930)
* feat(cli): allow to override default script extension closes #929 * test: up size-limit * docs: mention `--ext` flag in man * refactor: handle extensions without dot
1 parent 3f164e0
man/zx.1
@@ -21,6 +21,8 @@ prefix all commands
 postfix all commands
 .SS --eval=<js>, -e
 evaluate script
+.SS --ext=<.mjs>
+default extension
 .SS --install, -i
 install dependencies
 .SS --repl
src/cli.ts
@@ -29,6 +29,8 @@ import { installDeps, parseDeps } from './deps.js'
 import { randomId } from './util.js'
 import { createRequire } from './vendor.js'
 
+const EXT = '.mjs'
+
 isMain() &&
   main().catch((err) => {
     if (err instanceof ProcessOutput) {
@@ -56,6 +58,7 @@ export function printUsage() {
    --postfix=<command>  postfix all commands
    --cwd=<path>         set current directory
    --eval=<js>, -e      evaluate script 
+   --ext=<.mjs>         default extension
    --install, -i        install dependencies
    --version, -v        print current zx version
    --help, -h           print help
@@ -67,7 +70,7 @@ export function printUsage() {
 }
 
 export const argv = minimist(process.argv.slice(2), {
-  string: ['shell', 'prefix', 'postfix', 'eval', 'cwd'],
+  string: ['shell', 'prefix', 'postfix', 'eval', 'cwd', 'ext'],
   boolean: [
     'version',
     'help',
@@ -83,6 +86,7 @@ export const argv = minimist(process.argv.slice(2), {
 
 export async function main() {
   await import('./globals.js')
+  argv.ext = normalizeExt(argv.ext)
   if (argv.cwd) $.cwd = argv.cwd
   if (argv.verbose) $.verbose = true
   if (argv.quiet) $.quiet = true
@@ -102,13 +106,13 @@ export async function main() {
     return
   }
   if (argv.eval) {
-    await runScript(argv.eval)
+    await runScript(argv.eval, argv.ext)
     return
   }
   const firstArg = argv._[0]
   updateArgv(argv._.slice(firstArg === undefined ? 0 : 1))
   if (!firstArg || firstArg === '-') {
-    const success = await scriptFromStdin()
+    const success = await scriptFromStdin(argv.ext)
     if (!success) {
       printUsage()
       process.exitCode = 1
@@ -116,7 +120,7 @@ export async function main() {
     return
   }
   if (/^https?:/.test(firstArg)) {
-    await scriptFromHttp(firstArg)
+    await scriptFromHttp(firstArg, argv.ext)
     return
   }
   const filepath = firstArg.startsWith('file:///')
@@ -125,12 +129,12 @@ export async function main() {
   await importPath(filepath)
 }
 
-export async function runScript(script: string) {
-  const filepath = path.join($.cwd ?? process.cwd(), `zx-${randomId()}.mjs`)
+export async function runScript(script: string, ext = EXT) {
+  const filepath = path.join($.cwd ?? process.cwd(), `zx-${randomId()}${ext}`)
   await writeAndImport(script, filepath)
 }
 
-export async function scriptFromStdin() {
+export async function scriptFromStdin(ext?: string) {
   let script = ''
   if (!process.stdin.isTTY) {
     process.stdin.setEncoding('utf8')
@@ -139,14 +143,14 @@ export async function scriptFromStdin() {
     }
 
     if (script.length > 0) {
-      await runScript(script)
+      await runScript(script, ext)
       return true
     }
   }
   return false
 }
 
-export async function scriptFromHttp(remote: string) {
+export async function scriptFromHttp(remote: string, _ext = EXT) {
   const res = await fetch(remote)
   if (!res.ok) {
     console.error(`Error: Can't get ${remote}`)
@@ -155,7 +159,7 @@ export async function scriptFromHttp(remote: string) {
   const script = await res.text()
   const pathname = new URL(remote).pathname
   const name = path.basename(pathname)
-  const ext = path.extname(pathname) || '.mjs'
+  const ext = path.extname(pathname) || _ext
   const cwd = $.cwd ?? process.cwd()
   const filepath = path.join(cwd, `${name}-${randomId()}${ext}`)
   await writeAndImport(script, filepath)
@@ -299,3 +303,9 @@ export function isMain(
 
   return false
 }
+
+export function normalizeExt(ext?: string) {
+  if (!ext) return
+  if (!/^\.?\w+(\.\w+)*$/.test(ext)) throw new Error(`Invalid extension ${ext}`)
+  return ext[0] === '.' ? ext : `.${ext}`
+}
test/cli.test.js
@@ -16,10 +16,13 @@ import assert from 'node:assert'
 import { test, describe, before, after } from 'node:test'
 import { fileURLToPath } from 'node:url'
 import '../build/globals.js'
-import { isMain } from '../build/cli.js'
+import { isMain, normalizeExt } from '../build/cli.js'
 
 const __filename = fileURLToPath(import.meta.url)
 const spawn = $.spawn
+const nodeMajor = +process.versions?.node?.split('.')[0]
+const test22 = nodeMajor >= 22 ? test : test.skip
+
 describe('cli', () => {
   // Helps detect unresolved ProcessPromise.
   before(() => {
@@ -144,6 +147,12 @@ describe('cli', () => {
     )
   })
 
+  test22('scripts from stdin with explicit extension', async () => {
+    const out =
+      await $`node --experimental-strip-types build/cli.js --ext='.ts' <<< 'const foo: string = "bar"; console.log(foo)'`
+    assert.match(out.stdout, /bar/)
+  })
+
   test('require() is working from stdin', async () => {
     const out =
       await $`node build/cli.js <<< 'console.log(require("./package.json").name)'`
@@ -258,4 +267,11 @@ describe('cli', () => {
       }
     })
   })
+
+  test('normalizeExt()', () => {
+    assert.equal(normalizeExt('.ts'), '.ts')
+    assert.equal(normalizeExt('ts'), '.ts')
+    assert.equal(normalizeExt(), undefined)
+    assert.throws(() => normalizeExt('.'))
+  })
 })
.size-limit.json
@@ -30,7 +30,7 @@
   {
     "name": "all",
     "path": "build/*",
-    "limit": "833 kB",
+    "limit": "833.6 kB",
     "brotli": false,
     "gzip": false
   }