diff --git a/cmd/commands/activate.go b/cmd/commands/activate.go index 75159c23..99d7f674 100644 --- a/cmd/commands/activate.go +++ b/cmd/commands/activate.go @@ -18,6 +18,7 @@ package commands import ( "fmt" + "github.com/version-fox/vfox/internal/util" "os" "strings" "text/template" @@ -46,25 +47,18 @@ func activateCmd(ctx *cli.Context) error { manager := internal.NewSdkManager() defer manager.Close() - workToolVersion, err := toolset.NewToolVersion(manager.PathMeta.WorkingDirectory) + toolVersionMap, err := manager.FindToolVersion() if err != nil { return err } - if err = manager.ParseLegacyFile(func(sdkname, version string) { - if _, ok := workToolVersion.Record[sdkname]; !ok { - workToolVersion.Record[sdkname] = version - } - }); err != nil { - return err - } homeToolVersion, err := toolset.NewToolVersion(manager.PathMeta.HomePath) if err != nil { return err } - sdkEnvs, err := manager.EnvKeys(toolset.MultiToolVersions{ - workToolVersion, - homeToolVersion, + sdkEnvs, err := manager.EnvKeys([]*util.SortedMap[string, string]{ + toolVersionMap, + homeToolVersion.SortedMap, }, internal.ShellLocation) if err != nil { return err diff --git a/cmd/commands/env.go b/cmd/commands/env.go index 9962919a..553936ef 100644 --- a/cmd/commands/env.go +++ b/cmd/commands/env.go @@ -26,6 +26,8 @@ import ( "github.com/version-fox/vfox/internal/logger" "github.com/version-fox/vfox/internal/shell" "github.com/version-fox/vfox/internal/toolset" + "github.com/version-fox/vfox/internal/util" + "os" "path/filepath" ) @@ -109,6 +111,12 @@ func cleanTmp() error { return nil } +type flushCacheItem struct { + Version string + Path []string + Envs map[string]*string +} + func envFlag(ctx *cli.Context) error { shellName := ctx.String("shell") if shellName == "" { @@ -147,16 +155,9 @@ func envFlag(ctx *cli.Context) error { } func aggregateEnvKeys(manager *internal.Manager) (internal.SdkEnvs, error) { - workToolVersion, err := toolset.NewToolVersion(manager.PathMeta.WorkingDirectory) - if err != nil { - return nil, err - } - if err = manager.ParseLegacyFile(func(sdkname, version string) { - if _, ok := workToolVersion.Record[sdkname]; !ok { - workToolVersion.Record[sdkname] = version - } - }); err != nil { + toolVersionMap, err := manager.FindToolVersion() + if err != nil { return nil, err } @@ -164,44 +165,97 @@ func aggregateEnvKeys(manager *internal.Manager) (internal.SdkEnvs, error) { if err != nil { return nil, err } - defer curToolVersion.Save() // Add the working directory to the first - tvs := toolset.MultiToolVersions{workToolVersion, curToolVersion} + multiToolVersions := []*util.SortedMap[string, string]{ + toolVersionMap, + curToolVersion.SortedMap, + } - flushCache, err := cache.NewFileCache(filepath.Join(manager.PathMeta.CurTmpPath, "flush_env.cache")) + flushCacheFile := filepath.Join(manager.PathMeta.CurTmpPath, "flush_env.cache") + flushCache, err := cache.NewFileCache(flushCacheFile) if err != nil { return nil, err } defer flushCache.Close() - var sdkEnvs []*internal.SdkEnv + var ( + sdkEnvs []*internal.SdkEnv + finalSdks = make(map[string]*flushCacheItem) + cacheLen = flushCache.Len() + ) - tvs.FilterTools(func(name, version string) bool { - if lookupSdk, err := manager.LookupSdk(name); err == nil { - vv, ok := flushCache.Get(name) - if ok && string(vv) == version { - logger.Debugf("Hit cache, skip flush envrionment, %s@%s\n", name, version) - return true - } else { - logger.Debugf("No hit cache, name: %s cache: %s, expected: %s \n", name, string(vv), version) + for _, tv := range multiToolVersions { + if err = tv.ForEach(func(name string, version string) error { + if _, ok := finalSdks[name]; ok { + return nil } - v := internal.Version(version) - if keys, err := lookupSdk.EnvKeys(v, internal.ShellLocation); err == nil { - flushCache.Set(name, cache.Value(version), cache.NeverExpired) - - sdkEnvs = append(sdkEnvs, &internal.SdkEnv{ - Sdk: lookupSdk, Env: keys, - }) - - // If we encounter a .tool-versions file, it is valid for the entire shell session, - // unless we encounter the next .tool-versions file or manually switch to the use command. - curToolVersion.Record[name] = version - return true + if lookupSdk, err := manager.LookupSdk(name); err == nil { + vv, ok := flushCache.Get(name) + if ok { + item := flushCacheItem{} + if err = vv.Unmarshal(&item); err != nil { + _ = os.Remove(flushCacheFile) + flushCache.Clear() + } + if item.Version == version { + logger.Debugf("Hit cache, skip flush envrionment, %s@%s\n", name, version) + finalSdks[name] = &item + return nil + } + } else { + logger.Debugf("No hit cache, name: %s, expected: %s \n", name, version) + } + v := internal.Version(version) + if keys, err := lookupSdk.EnvKeys(v, internal.ShellLocation); err == nil { + item := flushCacheItem{ + Version: version, + Path: keys.Paths.Slice(), + Envs: keys.Variables, + } + value, _ := cache.NewValue(&item) + flushCache.Set(name, value, cache.NeverExpired) + finalSdks[name] = &item + } } + return nil + }); err != nil { + return nil, err } - return false - }) + } + // Remove the old cache + if cacheLen != len(finalSdks) { + for _, sdkname := range flushCache.Keys() { + item, ok := finalSdks[sdkname] + // Remove the corresponding environment variable + if !ok { + linkPath := filepath.Join(manager.PathMeta.CurTmpPath, sdkname) + logger.Debugf("Remove unused sdk link: %s\n", linkPath) + _ = os.Remove(linkPath) + cv, _ := flushCache.Get(sdkname) + item = &flushCacheItem{} + if err = cv.Unmarshal(&item); err == nil { + newEnvs := make(env.Vars) + for k, _ := range item.Envs { + newEnvs[k] = nil + } + item.Envs = newEnvs + item.Path = make([]string, 0) + } + flushCache.Remove(sdkname) + } + paths := env.NewPaths(env.EmptyPaths) + for _, p := range item.Path { + paths.Add(p) + } + sdkEnvs = append(sdkEnvs, &internal.SdkEnv{ + Sdk: nil, Env: &env.Envs{ + Variables: item.Envs, + Paths: paths, + }, + }) + } + } return sdkEnvs, nil } diff --git a/internal/cache/cache.go b/internal/cache/cache.go index fc84b0e9..a88ee0ba 100644 --- a/internal/cache/cache.go +++ b/internal/cache/cache.go @@ -101,6 +101,28 @@ func (c *FileCache) Get(key string) (Value, bool) { return item.Val, true } +func (c *FileCache) Keys() []string { + c.mu.RLock() + defer c.mu.RUnlock() + keys := make([]string, 0, len(c.items)) + for k := range c.items { + keys = append(keys, k) + } + return keys +} + +func (c *FileCache) Len() int { + return len(c.items) +} + +func (c *FileCache) Clear() int { + c.mu.Lock() + defer c.mu.Unlock() + len := len(c.items) + c.items = make(map[string]Item) + return len +} + // Remove a key from the cache func (c *FileCache) Remove(key string) { c.mu.Lock() diff --git a/internal/manager.go b/internal/manager.go index 947d84f1..0772bcaf 100644 --- a/internal/manager.go +++ b/internal/manager.go @@ -28,6 +28,7 @@ import ( "path/filepath" "strconv" "strings" + "sync" "github.com/mitchellh/go-ps" "github.com/pterm/pterm" @@ -67,13 +68,13 @@ type Manager struct { Config *config.Config } -func (m *Manager) EnvKeys(tvs toolset.MultiToolVersions, location Location) (SdkEnvs, error) { +func (m *Manager) EnvKeys(tvs []*util.SortedMap[string, string], location Location) (SdkEnvs, error) { var sdkEnvs SdkEnvs tools := make(map[string]struct{}) for _, t := range tvs { - for name, version := range t.Record { + _ = t.ForEach(func(name string, version string) error { if _, ok := tools[name]; ok { - continue + return nil } if lookupSdk, err := m.LookupSdk(name); err == nil { v := Version(version) @@ -86,7 +87,8 @@ func (m *Manager) EnvKeys(tvs toolset.MultiToolVersions, location Location) (Sdk }) } } - } + return nil + }) } return sdkEnvs, nil } @@ -216,7 +218,7 @@ func (m *Manager) Remove(pluginName string) error { return err } for _, filename := range source.Plugin.LegacyFilenames { - delete(lfr.Record, filename) + lfr.Remove(filename) } if err = lfr.Save(); err != nil { return fmt.Errorf("remove legacy filenames failed: %w", err) @@ -301,10 +303,10 @@ func (m *Manager) Update(pluginName string) error { return err } for _, filename := range sdk.Plugin.LegacyFilenames { - delete(lfr.Record, filename) + lfr.Remove(filename) } for _, filename := range tempPlugin.LegacyFilenames { - lfr.Record[filename] = pluginName + lfr.Set(filename, pluginName) } if err = lfr.Save(); err != nil { return fmt.Errorf("update legacy filenames failed: %w", err) @@ -458,7 +460,7 @@ func (m *Manager) Add(pluginName, url, alias string) error { return err } for _, filename := range tempPlugin.LegacyFilenames { - lfr.Record[filename] = pname + lfr.Set(filename, pname) } if err = lfr.Save(); err != nil { return fmt.Errorf("add legacy filenames failed: %w", err) @@ -654,8 +656,8 @@ func (m *Manager) ParseLegacyFile(output func(sdkname, version string)) error { } // There are some legacy files to be parsed. - if len(legacyFileRecord.Record) > 0 { - for filename, sdkname := range legacyFileRecord.Record { + if legacyFileRecord.Len() > 0 { + _ = legacyFileRecord.ForEach(func(filename string, sdkname string) error { path := filepath.Join(m.PathMeta.WorkingDirectory, filename) if util.FileExists(path) { logger.Debugf("Parsing legacy file %s \n", path) @@ -669,12 +671,83 @@ func (m *Manager) ParseLegacyFile(output func(sdkname, version string)) error { } } } - } + return nil + }) } return nil } +func (m *Manager) FindToolVersion() (*util.SortedMap[string, string], error) { + path := m.PathMeta.WorkingDirectory + legacyFileRecord, err := m.loadLegacyFileRecord() + if err != nil { + return nil, err + } + versionMap := util.NewSortedMap[string, string]() + for { + logger.Debugf("Find tool version in %s \n", pterm.LightBlue(path)) + file := filepath.Join(path, toolset.ToolVersionFilename) + if util.FileExists(file) { + logger.Debugf("Parsing tool version file: %s \n", file) + wtv, err := toolset.NewFileRecord(file) + if err != nil { + return nil, err + } + versionMap = wtv.SortedMap + } + if m.Config.LegacyVersionFile.Enable { + var wg sync.WaitGroup + for _, filename := range legacyFileRecord.Keys() { + wg.Add(1) + go func(fn string) { + defer wg.Done() + sdkname, _ := legacyFileRecord.Get(fn) + legacyFile := filepath.Join(m.PathMeta.WorkingDirectory, fn) + if util.FileExists(legacyFile) { + logger.Debugf("Parsing legacy file %s \n", legacyFile) + if sdk, err := m.LookupSdk(sdkname); err == nil { + // The .tool-version in the current directory has the highest priority, + // checking to see if the version information in the legacy file exists in the former, + // and updating to the former record (Don’t fall into the file!) if it doesn't. + if version, err := sdk.ParseLegacyFile(legacyFile); err == nil && version != "" { + logger.Debugf("Found %s@%s in %s \n", sdkname, version, legacyFile) + versionMap.Set(sdkname, string(version)) + } + } + } + }(filename) + } + wg.Wait() + } + if versionMap.Len() != 0 { + break + } + + parent := filepath.Dir(path) + if parent == path { + logger.Debugf("Reach root directory, stop searching: %s \n", parent) + + if versionMap.Len() == 0 { + logger.Debugf("%s \n", pterm.LightRed("Toolchain version not found, use globally version.")) + file := filepath.Join(m.PathMeta.HomePath, toolset.ToolVersionFilename) + if util.FileExists(file) { + logger.Debugf("Parsing tool version file: %s \n", file) + wtv, err := toolset.NewFileRecord(file) + if err != nil { + return nil, err + } + versionMap = wtv.SortedMap + } + } + break + } + path = parent + } + + return versionMap, nil +} + func NewSdkManager() *Manager { meta, err := newPathMeta() if err != nil { diff --git a/internal/sdk.go b/internal/sdk.go index e5a06773..6accdd19 100644 --- a/internal/sdk.go +++ b/internal/sdk.go @@ -40,10 +40,6 @@ import ( "github.com/version-fox/vfox/internal/util" ) -var ( - localSdkPackageCache = make(map[Version]*Package) -) - type Version string type SdkEnv struct { @@ -72,7 +68,8 @@ type Sdk struct { sdkManager *Manager Plugin *LuaPlugin // current sdk install path - InstallPath string + InstallPath string + localSdkPackageCache map[Version]*Package } func (b *Sdk) Install(version Version) error { @@ -245,7 +242,7 @@ func (b *Sdk) Uninstall(version Version) (err error) { if err != nil { return err } - delete(tv.Record, b.Plugin.SdkName) + tv.Remove(b.Plugin.SdkName) _ = tv.Save() err = os.RemoveAll(path) @@ -398,7 +395,7 @@ func (b *Sdk) Use(version Version, scope UseScope) error { } // clear global env - if oldVersion, ok := toolVersion.Record[b.Plugin.SdkName]; ok { + if oldVersion, ok := toolVersion.Get(b.Plugin.SdkName); ok { b.clearGlobalEnv(Version(oldVersion)) } if err = b.sdkManager.EnvManager.Load(keys); err != nil { @@ -408,7 +405,7 @@ func (b *Sdk) Use(version Version, scope UseScope) error { if err != nil { return err } - toolVersion.Record[b.Plugin.SdkName] = string(version) + toolVersion.Set(b.Plugin.SdkName, string(version)) if err = toolVersion.Save(); err != nil { return fmt.Errorf("failed to save tool versions, err:%w", err) } @@ -420,11 +417,11 @@ func (b *Sdk) Use(version Version, scope UseScope) error { func (b *Sdk) useInHook(version Version, scope UseScope) error { var ( - multiToolVersion toolset.MultiToolVersions + toolVersion toolset.ToolVersion ) if scope == Global { - toolVersion, err := toolset.NewToolVersion(b.sdkManager.PathMeta.HomePath) + tv, err := toolset.NewToolVersion(b.sdkManager.PathMeta.HomePath) if err != nil { return fmt.Errorf("failed to read tool versions, err:%w", err) } @@ -446,7 +443,7 @@ func (b *Sdk) useInHook(version Version, scope UseScope) error { // clear global env logger.Debugf("Clear global env: %s\n", b.Plugin.SdkName) - if oldVersion, ok := toolVersion.Record[b.Plugin.SdkName]; ok { + if oldVersion, ok := tv.Get(b.Plugin.SdkName); ok { b.clearGlobalEnv(Version(oldVersion)) } @@ -457,29 +454,28 @@ func (b *Sdk) useInHook(version Version, scope UseScope) error { if err != nil { return err } - multiToolVersion = append(multiToolVersion, toolVersion) + toolVersion = tv } else if scope == Project { - toolVersion, err := toolset.NewToolVersion(b.sdkManager.PathMeta.WorkingDirectory) + tv, err := toolset.NewToolVersion(b.sdkManager.PathMeta.WorkingDirectory) if err != nil { return fmt.Errorf("failed to read tool versions, err:%w", err) } - logger.Debugf("Load project toolchain versions: %v\n", toolVersion.Record) - multiToolVersion = append(multiToolVersion, toolVersion) - } + toolVersion = tv + } else { + tv, err := toolset.NewToolVersion(b.sdkManager.PathMeta.CurTmpPath) + if err != nil { + return fmt.Errorf("failed to read tool versions, err:%w", err) + } + toolVersion = tv - // It must also be saved once at the session level. - toolVersion, err := toolset.NewToolVersion(b.sdkManager.PathMeta.CurTmpPath) - if err != nil { - return fmt.Errorf("failed to read tool versions, err:%w", err) } - multiToolVersion = append(multiToolVersion, toolVersion) + toolVersion.Set(b.Plugin.SdkName, string(version)) - multiToolVersion.Add(b.Plugin.SdkName, string(version)) - - if err = multiToolVersion.Save(); err != nil { + if err := toolVersion.Save(); err != nil { return fmt.Errorf("failed to save tool versions, err:%w", err) } - if err = b.ToLinkPackage(version, ShellLocation); err != nil { + + if err := b.ToLinkPackage(version, ShellLocation); err != nil { return err } @@ -572,7 +568,7 @@ func (b *Sdk) clearGlobalEnv(version Version) { } func (b *Sdk) GetLocalSdkPackage(version Version) (*Package, error) { - p, ok := localSdkPackageCache[version] + p, ok := b.localSdkPackageCache[version] if ok { return p, nil } @@ -610,7 +606,7 @@ func (b *Sdk) GetLocalSdkPackage(version Version) (*Package, error) { Main: main, Additions: additions, } - localSdkPackageCache[version] = p2 + b.localSdkPackageCache[version] = p2 return p2, nil } @@ -748,7 +744,7 @@ func (b *Sdk) ClearCurrentEnv() error { return err } for _, tv := range toolVersion { - delete(tv.Record, b.Plugin.SdkName) + tv.Remove(b.Plugin.SdkName) } return nil } @@ -760,8 +756,9 @@ func NewSdk(manager *Manager, pluginPath string) (*Sdk, error) { return nil, fmt.Errorf("failed to create lua plugin: %w", err) } return &Sdk{ - sdkManager: manager, - InstallPath: filepath.Join(manager.PathMeta.SdkCachePath, strings.ToLower(luaPlugin.SdkName)), - Plugin: luaPlugin, + sdkManager: manager, + InstallPath: filepath.Join(manager.PathMeta.SdkCachePath, strings.ToLower(luaPlugin.SdkName)), + Plugin: luaPlugin, + localSdkPackageCache: make(map[Version]*Package), }, nil } diff --git a/internal/toolset/file_record.go b/internal/toolset/file_record.go index b0c727d3..c294a1af 100644 --- a/internal/toolset/file_record.go +++ b/internal/toolset/file_record.go @@ -26,26 +26,27 @@ import ( // FileRecord is a file that contains a map of string to string type FileRecord struct { - Record map[string]string - Path string + *util.SortedMap[string, string] + path string isInitEmpty bool } func (m *FileRecord) Save() error { - if m.isInitEmpty && len(m.Record) == 0 { + if m.isInitEmpty && m.Len() == 0 { return nil } - file, err := os.Create(m.Path) + file, err := os.Create(m.path) if err != nil { - return fmt.Errorf("failed to create file record %s: %w", m.Path, err) + return fmt.Errorf("failed to create file record %s: %w", m.path, err) } defer file.Close() - for k, v := range m.Record { - _, err := fmt.Fprintf(file, "%s %s\n", k, v) - if err != nil { - return err - } + err = m.ForEach(func(k string, v string) error { + _, err = fmt.Fprintf(file, "%s %s\n", k, v) + return err + }) + if err != nil { + return err } return nil } @@ -53,7 +54,7 @@ func (m *FileRecord) Save() error { // NewFileRecord creates a new FileRecord from a file // if the file does not exist, an empty FileRecord is returned func NewFileRecord(path string) (*FileRecord, error) { - versionsMap := make(map[string]string) + versionsMap := util.NewSortedMap[string, string]() if util.FileExists(path) { file, err := os.Open(path) if err != nil { @@ -65,7 +66,7 @@ func NewFileRecord(path string) (*FileRecord, error) { line := scanner.Text() parts := strings.Split(line, " ") if len(parts) == 2 { - versionsMap[parts[0]] = parts[1] + versionsMap.Set(parts[0], parts[1]) } } @@ -74,8 +75,8 @@ func NewFileRecord(path string) (*FileRecord, error) { } } return &FileRecord{ - Record: versionsMap, - Path: path, - isInitEmpty: len(versionsMap) == 0, + SortedMap: versionsMap, + path: path, + isInitEmpty: versionsMap.Len() == 0, }, nil } diff --git a/internal/toolset/file_record_test.go b/internal/toolset/file_record_test.go new file mode 100644 index 00000000..552a5f55 --- /dev/null +++ b/internal/toolset/file_record_test.go @@ -0,0 +1,80 @@ +/* + * Copyright 2024 Han Li and contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package toolset + +import ( + "os" + "testing" +) + +func TestFileRecord(t *testing.T) { + // Create a temporary file for testing + file, err := os.CreateTemp("", "test") + if err != nil { + t.Fatal(err) + } + defer os.Remove(file.Name()) + + fm, _ := NewFileRecord(file.Name()) + + // Test Set method + fm.Set("key", "value") + + // Test Get method + if v, ok := fm.Get("key"); !ok || v != "value" { + t.Errorf("Expected 'value', got %s", v) + } + + // Test Contains method + if !fm.Contains("key") { + t.Errorf("Expected true, got false") + } + + // Test Len method + if fm.Len() != 1 { + t.Errorf("Expected 1, got %d", fm.Len()) + } + + // Test ForEach method + err = fm.ForEach(func(k string, v string) error { + if k != "key" || v != "value" { + t.Errorf("Expected key: 'key', value: 'value', got key: %s, value: %s", k, v) + } + return nil + }) + if err != nil { + t.Errorf("ForEach method failed with error: %v", err) + } + + // Test Save method + err = fm.Save() + if err != nil { + t.Errorf("Save method failed with error: %v", err) + } + + // Test Remove method + v := fm.Remove("key") + if v != "value" { + t.Errorf("Expected 'value', got %s", v) + } + if fm.Contains("key") { + t.Errorf("Expected false, got true") + } + if fm.Len() != 0 { + t.Errorf("Expected 0, got %d", fm.Len()) + } +} diff --git a/internal/toolset/tool_version.go b/internal/toolset/tool_version.go index 1f3fb337..e9b8ab62 100644 --- a/internal/toolset/tool_version.go +++ b/internal/toolset/tool_version.go @@ -21,28 +21,29 @@ import ( "path/filepath" ) -const filename = ".tool-versions" +const ToolVersionFilename = ".tool-versions" -type MultiToolVersions []*ToolVersion +type MultiToolVersions []ToolVersion // FilterTools filters tools by the given filter function // and return the first one you find. func (m MultiToolVersions) FilterTools(filter func(name, version string) bool) map[string]string { tools := make(map[string]string) for _, t := range m { - for name, version := range t.Record { + _ = t.ForEach(func(name string, version string) error { _, ok := tools[name] if !ok && filter(name, version) { tools[name] = version } - } + return nil + }) } return tools } func (m MultiToolVersions) Add(name, version string) { for _, t := range m { - t.Record[name] = version + t.Set(name, version) } } @@ -56,19 +57,15 @@ func (m MultiToolVersions) Save() error { } // ToolVersion represents a .tool-versions file -type ToolVersion struct { - *FileRecord -} +type ToolVersion = *FileRecord -func NewToolVersion(dirPath string) (*ToolVersion, error) { - file := filepath.Join(dirPath, filename) +func NewToolVersion(dirPath string) (ToolVersion, error) { + file := filepath.Join(dirPath, ToolVersionFilename) mapFile, err := NewFileRecord(file) if err != nil { return nil, fmt.Errorf("failed to read tool versions file %s: %w", file, err) } - return &ToolVersion{ - FileRecord: mapFile, - }, nil + return mapFile, nil } func NewMultiToolVersions(paths []string) (MultiToolVersions, error) { diff --git a/internal/util/map.go b/internal/util/map.go new file mode 100644 index 00000000..243a700d --- /dev/null +++ b/internal/util/map.go @@ -0,0 +1,92 @@ +/* + * Copyright 2024 Han Li and contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package util + +//type Map[K comparable, V any] interface { +// Get(k K) (V, bool) +// Set(k K, v V) bool +// Remove(k K) V +// Contains(k K) bool +// Len() int +// ForEach(f func(k K, v V) error) error +//} + +type SortedMap[K comparable, V any] struct { + keys []K + vals map[K]V +} + +func (s *SortedMap[K, V]) Get(k K) (V, bool) { + v, ok := s.vals[k] + return v, ok +} + +func (s *SortedMap[K, V]) Set(k K, v V) bool { + _, exists := s.vals[k] + if !exists { + s.keys = append(s.keys, k) + } + s.vals[k] = v + return !exists +} + +func (s *SortedMap[K, V]) Remove(k K) V { + v := s.vals[k] + delete(s.vals, k) + for i, key := range s.keys { + if key == k { + s.keys = append(s.keys[:i], s.keys[i+1:]...) + break + } + } + return v +} + +func (s *SortedMap[K, V]) Contains(k K) bool { + _, exists := s.vals[k] + return exists +} + +func (s *SortedMap[K, V]) Len() int { + return len(s.keys) +} + +func (s *SortedMap[K, V]) Keys() []K { + return append([]K{}, s.keys...) +} + +func (s *SortedMap[K, V]) Merge(sortedMap *SortedMap[K, V]) { + for _, k := range sortedMap.keys { + s.Set(k, sortedMap.vals[k]) + } +} + +func (s *SortedMap[K, V]) ForEach(f func(k K, v V) error) error { + for _, k := range s.keys { + if err := f(k, s.vals[k]); err != nil { + return err + } + } + return nil +} + +func NewSortedMap[K comparable, V any]() *SortedMap[K, V] { + return &SortedMap[K, V]{ + keys: []K{}, + vals: make(map[K]V), + } +} diff --git a/internal/util/map_test.go b/internal/util/map_test.go new file mode 100644 index 00000000..b5385068 --- /dev/null +++ b/internal/util/map_test.go @@ -0,0 +1,95 @@ +/* + * Copyright 2024 Han Li and contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package util + +import ( + "fmt" + "testing" +) + +func TestSortedMap(t *testing.T) { + sm := NewSortedMap[int, string]() + + // Test Set method + sm.Set(1, "one") + + // Test Get method + if _, ok := sm.Get(1); !ok { + t.Errorf("Expected true, got false") + } + + // Test Contains method + if !sm.Contains(1) { + t.Errorf("Expected true, got false") + } + + // Test Len method + if sm.Len() != 1 { + t.Errorf("Expected 1, got %d", sm.Len()) + } + + // Test ForEach method + err := sm.ForEach(func(k int, v string) error { + if k != 1 || v != "one" { + return fmt.Errorf("Expected key: 1, value: 'one', got key: %d, value: %s", k, v) + } + return nil + }) + if err != nil { + t.Errorf(err.Error()) + } + + // Test Remove method + val := sm.Remove(1) + if val != "one" { + t.Errorf("Expected 'one', got %s", val) + } + if sm.Contains(1) { + t.Errorf("Expected false, got true") + } + if sm.Len() != 0 { + t.Errorf("Expected 0, got %d", sm.Len()) + } +} + +func TestSortedMap_ForEach(t *testing.T) { + sm := NewSortedMap[int, string]() + sm.Set(1, "one") + sm.Set(2, "two") + sm.Set(3, "three") + + var keys []int + var values []string + _ = sm.ForEach(func(k int, v string) error { + keys = append(keys, k) + values = append(values, v) + return nil + }) + + if len(keys) != 3 { + t.Errorf("Expected 3, got %d", len(keys)) + } + if len(values) != 3 { + t.Errorf("Expected 3, got %d", len(values)) + } + if keys[0] != 1 || keys[1] != 2 || keys[2] != 3 { + t.Errorf("Expected [1 2 3], got %v", keys) + } + if values[0] != "one" || values[1] != "two" || values[2] != "three" { + t.Errorf("Expected ['one' 'two' 'three'], got %v", values) + } +}