Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 23 additions & 6 deletions src/cmd/shim/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,15 +155,32 @@ func handleNoConfiguredVersion(shimName, runtimeName string, provider runtime.Sh
return fmt.Errorf("no version configured")
}

// getShimName returns the name of this shim binary
// getShimName returns the name of this shim binary based on os.Args[0].
func getShimName() string {
shimPath := os.Args[0]
shimName := filepath.Base(shimPath)
return shimNameFromPath(os.Args[0])
}

// Remove .exe extension on Windows
shimName = strings.TrimSuffix(shimName, ".exe")
// shimNameFromPath derives the shim name from an invocation path.
//
// On Windows, the filename's .exe extension is stripped case-insensitively.
// This matters because Windows command resolution via PATHEXT can surface
// uppercase extensions (e.g., Python's shutil.which returns "mmdc.EXE"
// when PATHEXT contains ".EXE"). A case-sensitive TrimSuffix would leave
// the uppercase extension attached, breaking every downstream lookup in
// the shim-map cache and the provider registry.
func shimNameFromPath(shimPath string) string {
// Split on both separators so Windows-style paths resolve correctly even
// when this runs on a host where filepath.Base ignores backslashes.
if i := strings.LastIndexAny(shimPath, `/\`); i >= 0 {
shimPath = shimPath[i+1:]
}

return shimName
// Strip .exe / .EXE / any mixed case on Windows-style paths.
if ext := filepath.Ext(shimPath); strings.EqualFold(ext, constants.ExtExe) {
shimPath = shimPath[:len(shimPath)-len(ext)]
}

return shimPath
}

// mapShimToRuntime maps a shim name to its runtime
Expand Down
56 changes: 56 additions & 0 deletions src/cmd/shim/main_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
package main

import "testing"

func TestShimNameFromPath(t *testing.T) {
tests := []struct {
name string
shimPath string
want string
}{
{
name: "unix-style bare binary",
shimPath: "/home/user/.dtvem/shims/mmdc",
want: "mmdc",
},
{
name: "windows lowercase .exe",
shimPath: `C:\Users\calvin\.dtvem\shims\mmdc.exe`,
want: "mmdc",
},
{
name: "windows uppercase .EXE (PATHEXT-resolved)",
shimPath: `C:\Users\calvin\.dtvem\shims\mmdc.EXE`,
want: "mmdc",
},
{
name: "windows mixed case .Exe",
shimPath: `C:\Users\calvin\.dtvem\shims\mmdc.Exe`,
want: "mmdc",
},
{
name: "forward-slash path with uppercase extension",
shimPath: "C:/Users/calvin/.dtvem/shims/npm.EXE",
want: "npm",
},
{
name: "bare shim name without extension",
shimPath: "mmdc",
want: "mmdc",
},
{
name: "non-.exe extension is preserved (not stripped)",
shimPath: `C:\tools\something.bat`,
want: "something.bat",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := shimNameFromPath(tt.shimPath)
if got != tt.want {
t.Errorf("shimNameFromPath(%q) = %q, want %q", tt.shimPath, got, tt.want)
}
})
}
}
28 changes: 28 additions & 0 deletions src/internal/shim/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,34 @@ func SaveShimMap(shimMap ShimMap) error {
return os.WriteFile(cachePath, data, 0644)
}

// MergeShimMap merges the given entries into the on-disk shim map and persists it.
//
// If the cache does not exist yet (first-time install), a new map is created.
// Existing entries with matching keys are overwritten. The in-memory cache is
// reset so subsequent LoadShimMap calls read the updated state from disk.
//
// This is the preferred path for install-time shim registration, where the
// caller knows only the shims it just created and wants to register them
// without rebuilding the entire map (which would require scanning every
// installed runtime — `Rehash` does that).
func MergeShimMap(entries ShimMap) error {
existing, err := loadShimMapFromDisk()
if err != nil || existing == nil {
// Cache missing, unreadable, or empty — start a fresh map.
existing = make(ShimMap, len(entries))
}

for shim, runtime := range entries {
existing[shim] = runtime
}

// Force the next LoadShimMap to re-read from disk so the merged entries
// are visible to any subsequent caller in the same process.
ResetShimMapCache()

return SaveShimMap(existing)
}

// LookupRuntime looks up the runtime for a given shim name using the cache.
// Returns the runtime name and true if found, or empty string and false if not.
func LookupRuntime(shimName string) (string, bool) {
Expand Down
176 changes: 176 additions & 0 deletions src/internal/shim/cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -218,3 +218,179 @@ func TestShimMapCacheOnlyLoadsOnce(t *testing.T) {
t.Errorf("Cache should not have reloaded - 'new' entry should not exist")
}
}

func TestMergeShimMap_CreatesWhenNoExistingCache(t *testing.T) {
tempDir := t.TempDir()

originalRoot := os.Getenv("DTVEM_ROOT")
_ = os.Setenv("DTVEM_ROOT", tempDir)
defer func() { _ = os.Setenv("DTVEM_ROOT", originalRoot) }()

config.ResetPathsCache()
defer config.ResetPathsCache()

ResetShimMapCache()
defer ResetShimMapCache()

// No cache directory pre-existing — MergeShimMap must create it from scratch.
entries := ShimMap{
"node": "node",
"npm": "node",
"npx": "node",
}

if err := MergeShimMap(entries); err != nil {
t.Fatalf("MergeShimMap returned error on fresh install: %v", err)
}

loaded, err := LoadShimMap()
if err != nil {
t.Fatalf("LoadShimMap after MergeShimMap failed: %v", err)
}

if len(loaded) != len(entries) {
t.Errorf("expected %d entries, got %d (%v)", len(entries), len(loaded), loaded)
}
for shim, runtime := range entries {
if got := loaded[shim]; got != runtime {
t.Errorf("entry %q: expected runtime %q, got %q", shim, runtime, got)
}
}
}

func TestMergeShimMap_MergesIntoExistingCache(t *testing.T) {
tempDir := t.TempDir()

originalRoot := os.Getenv("DTVEM_ROOT")
_ = os.Setenv("DTVEM_ROOT", tempDir)
defer func() { _ = os.Setenv("DTVEM_ROOT", originalRoot) }()

config.ResetPathsCache()
defer config.ResetPathsCache()

ResetShimMapCache()
defer ResetShimMapCache()

cacheDir := filepath.Join(tempDir, "cache")
if err := os.MkdirAll(cacheDir, 0755); err != nil {
t.Fatalf("Failed to create cache directory: %v", err)
}

// Seed an existing cache (simulates a prior install).
initial := ShimMap{
"python": "python",
"pip": "python",
}
if err := SaveShimMap(initial); err != nil {
t.Fatalf("seed SaveShimMap failed: %v", err)
}

// Merge in a disjoint set of entries (simulates installing a second runtime).
added := ShimMap{
"node": "node",
"npm": "node",
}
if err := MergeShimMap(added); err != nil {
t.Fatalf("MergeShimMap failed: %v", err)
}

loaded, err := LoadShimMap()
if err != nil {
t.Fatalf("LoadShimMap failed: %v", err)
}

// All four entries should now be present.
wantAll := ShimMap{
"python": "python",
"pip": "python",
"node": "node",
"npm": "node",
}
for shim, runtime := range wantAll {
if got := loaded[shim]; got != runtime {
t.Errorf("entry %q: expected runtime %q, got %q", shim, runtime, got)
}
}
}

func TestMergeShimMap_OverwritesExistingKeys(t *testing.T) {
tempDir := t.TempDir()

originalRoot := os.Getenv("DTVEM_ROOT")
_ = os.Setenv("DTVEM_ROOT", tempDir)
defer func() { _ = os.Setenv("DTVEM_ROOT", originalRoot) }()

config.ResetPathsCache()
defer config.ResetPathsCache()

ResetShimMapCache()
defer ResetShimMapCache()

cacheDir := filepath.Join(tempDir, "cache")
if err := os.MkdirAll(cacheDir, 0755); err != nil {
t.Fatalf("Failed to create cache directory: %v", err)
}

// Seed with a stale mapping (e.g., a shim that was previously attributed
// to the wrong runtime by some prior state).
stale := ShimMap{"corepack": "wrong"}
if err := SaveShimMap(stale); err != nil {
t.Fatalf("seed SaveShimMap failed: %v", err)
}

// Merge should overwrite with the correct runtime.
if err := MergeShimMap(ShimMap{"corepack": "node"}); err != nil {
t.Fatalf("MergeShimMap failed: %v", err)
}

loaded, err := LoadShimMap()
if err != nil {
t.Fatalf("LoadShimMap failed: %v", err)
}

if got := loaded["corepack"]; got != "node" {
t.Errorf("expected corepack remapped to node, got %q", got)
}
}

func TestMergeShimMap_ResetsInMemoryCache(t *testing.T) {
tempDir := t.TempDir()

originalRoot := os.Getenv("DTVEM_ROOT")
_ = os.Setenv("DTVEM_ROOT", tempDir)
defer func() { _ = os.Setenv("DTVEM_ROOT", originalRoot) }()

config.ResetPathsCache()
defer config.ResetPathsCache()

ResetShimMapCache()
defer ResetShimMapCache()

cacheDir := filepath.Join(tempDir, "cache")
if err := os.MkdirAll(cacheDir, 0755); err != nil {
t.Fatalf("Failed to create cache directory: %v", err)
}

// Prime the in-memory cache with an initial map.
if err := SaveShimMap(ShimMap{"node": "node"}); err != nil {
t.Fatalf("SaveShimMap failed: %v", err)
}
if _, err := LoadShimMap(); err != nil {
t.Fatalf("initial LoadShimMap failed: %v", err)
}

// Without ResetShimMapCache, the next Load would return the cached copy.
// MergeShimMap is supposed to reset it so callers see merged state.
if err := MergeShimMap(ShimMap{"npm": "node"}); err != nil {
t.Fatalf("MergeShimMap failed: %v", err)
}

loaded, err := LoadShimMap()
if err != nil {
t.Fatalf("post-merge LoadShimMap failed: %v", err)
}

if _, ok := loaded["npm"]; !ok {
t.Error("expected in-memory cache to be reset so the merged 'npm' entry is visible")
}
}
22 changes: 22 additions & 0 deletions src/internal/shim/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,28 @@ func (m *Manager) CreateShims(shimNames []string) error {
return nil
}

// CreateShimsForRuntime creates shim files for the given names and registers
// them in the shim map under the given runtime name.
//
// This is the preferred path for install-time shim creation (e.g., from a
// runtime provider's post-install hook). Bare CreateShims only writes the
// shim binaries to disk — it does not update the shim-map cache, which means
// subsequent shim invocations have to fall back to the provider registry
// lookup instead of the O(1) cache hit. Calling CreateShimsForRuntime keeps
// the shim files and the cache in sync from the moment they are created.
func (m *Manager) CreateShimsForRuntime(runtimeName string, shimNames []string) error {
if err := m.CreateShims(shimNames); err != nil {
return err
}

entries := make(ShimMap, len(shimNames))
for _, name := range shimNames {
entries[name] = runtimeName
}

return MergeShimMap(entries)
}

// RemoveShim removes a shim
func (m *Manager) RemoveShim(shimName string) error {
shimPath := config.ShimPath(shimName)
Expand Down
8 changes: 5 additions & 3 deletions src/runtimes/node/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,9 @@ func (p *Provider) getDownloadURL(version string) (string, string, error) {
return dl.URL, archiveName, nil
}

// createShims creates shims for Node.js executables
// createShims creates shims for Node.js executables and registers them in the
// shim-map cache so subsequent shim invocations resolve via O(1) lookup rather
// than falling back to the provider registry.
func (p *Provider) createShims() error {
manager, err := shim.NewManager()
if err != nil {
Expand All @@ -162,8 +164,8 @@ func (p *Provider) createShims() error {
// Get the list of shims for Node.js
shimNames := shim.RuntimeShims("node")

// Create each shim
return manager.CreateShims(shimNames)
// Create each shim AND record them in the shim map cache
return manager.CreateShimsForRuntime("node", shimNames)
}

// Uninstall removes an installed version
Expand Down
8 changes: 5 additions & 3 deletions src/runtimes/python/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,9 @@ func (p *Provider) getDownloadURL(version string) (string, string, error) {
return dl.URL, archiveName, nil
}

// createShims creates shims for Python executables
// createShims creates shims for Python executables and registers them in the
// shim-map cache so subsequent shim invocations resolve via O(1) lookup rather
// than falling back to the provider registry.
func (p *Provider) createShims() error {
manager, err := shim.NewManager()
if err != nil {
Expand All @@ -219,8 +221,8 @@ func (p *Provider) createShims() error {
// Get the list of shims for Python
shimNames := shim.RuntimeShims("python")

// Create each shim
return manager.CreateShims(shimNames)
// Create each shim AND record them in the shim map cache
return manager.CreateShimsForRuntime("python", shimNames)
}

// installPip ensures pip is properly installed with working executables.
Expand Down
Loading
Loading