diff --git a/src/cmd/shim/main.go b/src/cmd/shim/main.go index 880e3d3..63ac615 100644 --- a/src/cmd/shim/main.go +++ b/src/cmd/shim/main.go @@ -155,15 +155,32 @@ func handleNoConfiguredVersion(shimName, runtimeName string, provider runtime.Sh return fmt.Errorf("no version configured") } -// getShimName returns the name of this shim binary +// getShimName returns the name of this shim binary based on os.Args[0]. func getShimName() string { - shimPath := os.Args[0] - shimName := filepath.Base(shimPath) + return shimNameFromPath(os.Args[0]) +} - // Remove .exe extension on Windows - shimName = strings.TrimSuffix(shimName, ".exe") +// shimNameFromPath derives the shim name from an invocation path. +// +// On Windows, the filename's .exe extension is stripped case-insensitively. +// This matters because Windows command resolution via PATHEXT can surface +// uppercase extensions (e.g., Python's shutil.which returns "mmdc.EXE" +// when PATHEXT contains ".EXE"). A case-sensitive TrimSuffix would leave +// the uppercase extension attached, breaking every downstream lookup in +// the shim-map cache and the provider registry. +func shimNameFromPath(shimPath string) string { + // Split on both separators so Windows-style paths resolve correctly even + // when this runs on a host where filepath.Base ignores backslashes. + if i := strings.LastIndexAny(shimPath, `/\`); i >= 0 { + shimPath = shimPath[i+1:] + } - return shimName + // Strip .exe / .EXE / any mixed case on Windows-style paths. + if ext := filepath.Ext(shimPath); strings.EqualFold(ext, constants.ExtExe) { + shimPath = shimPath[:len(shimPath)-len(ext)] + } + + return shimPath } // mapShimToRuntime maps a shim name to its runtime diff --git a/src/cmd/shim/main_test.go b/src/cmd/shim/main_test.go new file mode 100644 index 0000000..2667f6d --- /dev/null +++ b/src/cmd/shim/main_test.go @@ -0,0 +1,56 @@ +package main + +import "testing" + +func TestShimNameFromPath(t *testing.T) { + tests := []struct { + name string + shimPath string + want string + }{ + { + name: "unix-style bare binary", + shimPath: "/home/user/.dtvem/shims/mmdc", + want: "mmdc", + }, + { + name: "windows lowercase .exe", + shimPath: `C:\Users\calvin\.dtvem\shims\mmdc.exe`, + want: "mmdc", + }, + { + name: "windows uppercase .EXE (PATHEXT-resolved)", + shimPath: `C:\Users\calvin\.dtvem\shims\mmdc.EXE`, + want: "mmdc", + }, + { + name: "windows mixed case .Exe", + shimPath: `C:\Users\calvin\.dtvem\shims\mmdc.Exe`, + want: "mmdc", + }, + { + name: "forward-slash path with uppercase extension", + shimPath: "C:/Users/calvin/.dtvem/shims/npm.EXE", + want: "npm", + }, + { + name: "bare shim name without extension", + shimPath: "mmdc", + want: "mmdc", + }, + { + name: "non-.exe extension is preserved (not stripped)", + shimPath: `C:\tools\something.bat`, + want: "something.bat", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := shimNameFromPath(tt.shimPath) + if got != tt.want { + t.Errorf("shimNameFromPath(%q) = %q, want %q", tt.shimPath, got, tt.want) + } + }) + } +} diff --git a/src/internal/shim/cache.go b/src/internal/shim/cache.go index fc5dd3c..a7f9d24 100644 --- a/src/internal/shim/cache.go +++ b/src/internal/shim/cache.go @@ -66,6 +66,34 @@ func SaveShimMap(shimMap ShimMap) error { return os.WriteFile(cachePath, data, 0644) } +// MergeShimMap merges the given entries into the on-disk shim map and persists it. +// +// If the cache does not exist yet (first-time install), a new map is created. +// Existing entries with matching keys are overwritten. The in-memory cache is +// reset so subsequent LoadShimMap calls read the updated state from disk. +// +// This is the preferred path for install-time shim registration, where the +// caller knows only the shims it just created and wants to register them +// without rebuilding the entire map (which would require scanning every +// installed runtime — `Rehash` does that). +func MergeShimMap(entries ShimMap) error { + existing, err := loadShimMapFromDisk() + if err != nil || existing == nil { + // Cache missing, unreadable, or empty — start a fresh map. + existing = make(ShimMap, len(entries)) + } + + for shim, runtime := range entries { + existing[shim] = runtime + } + + // Force the next LoadShimMap to re-read from disk so the merged entries + // are visible to any subsequent caller in the same process. + ResetShimMapCache() + + return SaveShimMap(existing) +} + // LookupRuntime looks up the runtime for a given shim name using the cache. // Returns the runtime name and true if found, or empty string and false if not. func LookupRuntime(shimName string) (string, bool) { diff --git a/src/internal/shim/cache_test.go b/src/internal/shim/cache_test.go index 0341e1b..1974439 100644 --- a/src/internal/shim/cache_test.go +++ b/src/internal/shim/cache_test.go @@ -218,3 +218,179 @@ func TestShimMapCacheOnlyLoadsOnce(t *testing.T) { t.Errorf("Cache should not have reloaded - 'new' entry should not exist") } } + +func TestMergeShimMap_CreatesWhenNoExistingCache(t *testing.T) { + tempDir := t.TempDir() + + originalRoot := os.Getenv("DTVEM_ROOT") + _ = os.Setenv("DTVEM_ROOT", tempDir) + defer func() { _ = os.Setenv("DTVEM_ROOT", originalRoot) }() + + config.ResetPathsCache() + defer config.ResetPathsCache() + + ResetShimMapCache() + defer ResetShimMapCache() + + // No cache directory pre-existing — MergeShimMap must create it from scratch. + entries := ShimMap{ + "node": "node", + "npm": "node", + "npx": "node", + } + + if err := MergeShimMap(entries); err != nil { + t.Fatalf("MergeShimMap returned error on fresh install: %v", err) + } + + loaded, err := LoadShimMap() + if err != nil { + t.Fatalf("LoadShimMap after MergeShimMap failed: %v", err) + } + + if len(loaded) != len(entries) { + t.Errorf("expected %d entries, got %d (%v)", len(entries), len(loaded), loaded) + } + for shim, runtime := range entries { + if got := loaded[shim]; got != runtime { + t.Errorf("entry %q: expected runtime %q, got %q", shim, runtime, got) + } + } +} + +func TestMergeShimMap_MergesIntoExistingCache(t *testing.T) { + tempDir := t.TempDir() + + originalRoot := os.Getenv("DTVEM_ROOT") + _ = os.Setenv("DTVEM_ROOT", tempDir) + defer func() { _ = os.Setenv("DTVEM_ROOT", originalRoot) }() + + config.ResetPathsCache() + defer config.ResetPathsCache() + + ResetShimMapCache() + defer ResetShimMapCache() + + cacheDir := filepath.Join(tempDir, "cache") + if err := os.MkdirAll(cacheDir, 0755); err != nil { + t.Fatalf("Failed to create cache directory: %v", err) + } + + // Seed an existing cache (simulates a prior install). + initial := ShimMap{ + "python": "python", + "pip": "python", + } + if err := SaveShimMap(initial); err != nil { + t.Fatalf("seed SaveShimMap failed: %v", err) + } + + // Merge in a disjoint set of entries (simulates installing a second runtime). + added := ShimMap{ + "node": "node", + "npm": "node", + } + if err := MergeShimMap(added); err != nil { + t.Fatalf("MergeShimMap failed: %v", err) + } + + loaded, err := LoadShimMap() + if err != nil { + t.Fatalf("LoadShimMap failed: %v", err) + } + + // All four entries should now be present. + wantAll := ShimMap{ + "python": "python", + "pip": "python", + "node": "node", + "npm": "node", + } + for shim, runtime := range wantAll { + if got := loaded[shim]; got != runtime { + t.Errorf("entry %q: expected runtime %q, got %q", shim, runtime, got) + } + } +} + +func TestMergeShimMap_OverwritesExistingKeys(t *testing.T) { + tempDir := t.TempDir() + + originalRoot := os.Getenv("DTVEM_ROOT") + _ = os.Setenv("DTVEM_ROOT", tempDir) + defer func() { _ = os.Setenv("DTVEM_ROOT", originalRoot) }() + + config.ResetPathsCache() + defer config.ResetPathsCache() + + ResetShimMapCache() + defer ResetShimMapCache() + + cacheDir := filepath.Join(tempDir, "cache") + if err := os.MkdirAll(cacheDir, 0755); err != nil { + t.Fatalf("Failed to create cache directory: %v", err) + } + + // Seed with a stale mapping (e.g., a shim that was previously attributed + // to the wrong runtime by some prior state). + stale := ShimMap{"corepack": "wrong"} + if err := SaveShimMap(stale); err != nil { + t.Fatalf("seed SaveShimMap failed: %v", err) + } + + // Merge should overwrite with the correct runtime. + if err := MergeShimMap(ShimMap{"corepack": "node"}); err != nil { + t.Fatalf("MergeShimMap failed: %v", err) + } + + loaded, err := LoadShimMap() + if err != nil { + t.Fatalf("LoadShimMap failed: %v", err) + } + + if got := loaded["corepack"]; got != "node" { + t.Errorf("expected corepack remapped to node, got %q", got) + } +} + +func TestMergeShimMap_ResetsInMemoryCache(t *testing.T) { + tempDir := t.TempDir() + + originalRoot := os.Getenv("DTVEM_ROOT") + _ = os.Setenv("DTVEM_ROOT", tempDir) + defer func() { _ = os.Setenv("DTVEM_ROOT", originalRoot) }() + + config.ResetPathsCache() + defer config.ResetPathsCache() + + ResetShimMapCache() + defer ResetShimMapCache() + + cacheDir := filepath.Join(tempDir, "cache") + if err := os.MkdirAll(cacheDir, 0755); err != nil { + t.Fatalf("Failed to create cache directory: %v", err) + } + + // Prime the in-memory cache with an initial map. + if err := SaveShimMap(ShimMap{"node": "node"}); err != nil { + t.Fatalf("SaveShimMap failed: %v", err) + } + if _, err := LoadShimMap(); err != nil { + t.Fatalf("initial LoadShimMap failed: %v", err) + } + + // Without ResetShimMapCache, the next Load would return the cached copy. + // MergeShimMap is supposed to reset it so callers see merged state. + if err := MergeShimMap(ShimMap{"npm": "node"}); err != nil { + t.Fatalf("MergeShimMap failed: %v", err) + } + + loaded, err := LoadShimMap() + if err != nil { + t.Fatalf("post-merge LoadShimMap failed: %v", err) + } + + if _, ok := loaded["npm"]; !ok { + t.Error("expected in-memory cache to be reset so the merged 'npm' entry is visible") + } +} diff --git a/src/internal/shim/manager.go b/src/internal/shim/manager.go index dcd8237..f4ff348 100644 --- a/src/internal/shim/manager.go +++ b/src/internal/shim/manager.go @@ -108,6 +108,28 @@ func (m *Manager) CreateShims(shimNames []string) error { return nil } +// CreateShimsForRuntime creates shim files for the given names and registers +// them in the shim map under the given runtime name. +// +// This is the preferred path for install-time shim creation (e.g., from a +// runtime provider's post-install hook). Bare CreateShims only writes the +// shim binaries to disk — it does not update the shim-map cache, which means +// subsequent shim invocations have to fall back to the provider registry +// lookup instead of the O(1) cache hit. Calling CreateShimsForRuntime keeps +// the shim files and the cache in sync from the moment they are created. +func (m *Manager) CreateShimsForRuntime(runtimeName string, shimNames []string) error { + if err := m.CreateShims(shimNames); err != nil { + return err + } + + entries := make(ShimMap, len(shimNames)) + for _, name := range shimNames { + entries[name] = runtimeName + } + + return MergeShimMap(entries) +} + // RemoveShim removes a shim func (m *Manager) RemoveShim(shimName string) error { shimPath := config.ShimPath(shimName) diff --git a/src/runtimes/node/provider.go b/src/runtimes/node/provider.go index 688a0ab..c137981 100644 --- a/src/runtimes/node/provider.go +++ b/src/runtimes/node/provider.go @@ -152,7 +152,9 @@ func (p *Provider) getDownloadURL(version string) (string, string, error) { return dl.URL, archiveName, nil } -// createShims creates shims for Node.js executables +// createShims creates shims for Node.js executables and registers them in the +// shim-map cache so subsequent shim invocations resolve via O(1) lookup rather +// than falling back to the provider registry. func (p *Provider) createShims() error { manager, err := shim.NewManager() if err != nil { @@ -162,8 +164,8 @@ func (p *Provider) createShims() error { // Get the list of shims for Node.js shimNames := shim.RuntimeShims("node") - // Create each shim - return manager.CreateShims(shimNames) + // Create each shim AND record them in the shim map cache + return manager.CreateShimsForRuntime("node", shimNames) } // Uninstall removes an installed version diff --git a/src/runtimes/python/provider.go b/src/runtimes/python/provider.go index 10086c8..5952d3c 100644 --- a/src/runtimes/python/provider.go +++ b/src/runtimes/python/provider.go @@ -209,7 +209,9 @@ func (p *Provider) getDownloadURL(version string) (string, string, error) { return dl.URL, archiveName, nil } -// createShims creates shims for Python executables +// createShims creates shims for Python executables and registers them in the +// shim-map cache so subsequent shim invocations resolve via O(1) lookup rather +// than falling back to the provider registry. func (p *Provider) createShims() error { manager, err := shim.NewManager() if err != nil { @@ -219,8 +221,8 @@ func (p *Provider) createShims() error { // Get the list of shims for Python shimNames := shim.RuntimeShims("python") - // Create each shim - return manager.CreateShims(shimNames) + // Create each shim AND record them in the shim map cache + return manager.CreateShimsForRuntime("python", shimNames) } // installPip ensures pip is properly installed with working executables. diff --git a/src/runtimes/ruby/provider.go b/src/runtimes/ruby/provider.go index 3730463..305b4f7 100644 --- a/src/runtimes/ruby/provider.go +++ b/src/runtimes/ruby/provider.go @@ -233,7 +233,9 @@ func (p *Provider) getDownloadURL(version string) (string, string, error) { return dl.URL, archiveName, nil } -// createShims creates shims for Ruby executables +// createShims creates shims for Ruby executables and registers them in the +// shim-map cache so subsequent shim invocations resolve via O(1) lookup rather +// than falling back to the provider registry. func (p *Provider) createShims() error { manager, err := shim.NewManager() if err != nil { @@ -243,8 +245,8 @@ func (p *Provider) createShims() error { // Get the list of shims for Ruby shimNames := shim.RuntimeShims("ruby") - // Create each shim - return manager.CreateShims(shimNames) + // Create each shim AND record them in the shim map cache + return manager.CreateShimsForRuntime("ruby", shimNames) } // Uninstall removes an installed version