[winsvc] Add interfaces for Windows services and service manager

Provides interfaces to the Windows service manager and Windows
services. These interfaces support creating new Windows services,
deleting Windows services, configuring Windows services, and
registering/deregistering services with Windows Eventlog.

A path helper is included to support expansion of paths using a
subset of known folder IDs.

A privileged helper is included to check that the process is
currently being executed with elevated privileges, which are
required for managing Windows services and modifying the registry.
This commit is contained in:
Chris Roberts
2025-08-05 13:50:26 -07:00
parent df3c74ff55
commit 48d91dc1f9
11 changed files with 1410 additions and 0 deletions

View File

@@ -13,6 +13,7 @@ project {
"ui/node_modules", "ui/node_modules",
"pnpm-workspace.yaml", "pnpm-workspace.yaml",
"pnpm-lock.yaml", "pnpm-lock.yaml",
"helper/winsvc/strings_*.go",
// Enterprise files do not fall under the open source licensing. CE-ENT // Enterprise files do not fall under the open source licensing. CE-ENT
// merge conflicts might happen here, please be sure to put new CE // merge conflicts might happen here, please be sure to put new CE

View File

@@ -0,0 +1,22 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
//go:build !windows
package winsvc
import "errors"
func NewWindowsPaths() WindowsPaths {
return &windowsPaths{}
}
type windowsPaths struct{}
func (w *windowsPaths) Expand(string) (string, error) {
return "", errors.New("Windows path expansion not supported on this platform")
}
func (w *windowsPaths) CreateDirectory(string, bool) error {
return errors.New("Windows directory creation not supported on this platform")
}

View File

@@ -0,0 +1,206 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package winsvc
import (
"bytes"
"errors"
"fmt"
"os"
"strings"
"sync"
"text/template"
"golang.org/x/sys/windows"
"golang.org/x/sys/windows/registry"
)
func NewWindowsPaths() WindowsPaths {
return &windowsPaths{}
}
type windowsPaths struct {
SystemRoot string
SystemDrive string
ProgramData string
ProgramFiles string
loadErr error
o sync.Once
}
func (w *windowsPaths) Expand(path string) (string, error) {
if err := w.load(); err != nil {
return "", err
}
tmpl := template.New("expansion").Option("missingkey=error")
tmpl, err := tmpl.Parse(path)
if err != nil {
return "", err
}
result := new(bytes.Buffer)
if err := tmpl.Execute(result, w); err != nil {
return "", err
}
return result.String(), nil
}
func (w *windowsPaths) CreateDirectory(path string, restrict_on_create bool) error {
s, err := os.Stat(path)
if err != nil && !errors.Is(err, os.ErrNotExist) {
return err
}
if err == nil {
// Directory exists so nothing to do
if s.IsDir() {
return nil
}
return fmt.Errorf("path exists and is not directory - %s", path)
}
// NOTE: mode ignored on Windows. If directory should
// be restricted, an ACL will be applied below.
if err := os.MkdirAll(path, 0o000); err != nil {
return err
}
// Since the directory was just created, apply access
// restrictions if requested
if restrict_on_create {
if err := setDirectoryPermissions(path); err != nil {
return err
}
}
return nil
}
func getUserGroupSIDs() (usid *windows.SID, gsid *windows.SID, err error) {
// NOTE: this token is a pseudo-token and does not
// need to be closed
token := windows.GetCurrentProcessToken()
userToken, err := token.GetTokenUser()
if err != nil {
return
}
usid = userToken.User.Sid
userGroup, err := token.GetTokenPrimaryGroup()
if err != nil {
return
}
gsid = userGroup.PrimaryGroup
return
}
func setDirectoryPermissions(path string) error {
// Grab the user and group SID for who is running the process
userSid, groupSid, err := getUserGroupSIDs()
if err != nil {
return err
}
// Generate a SID for the administators group
gsid, err := windows.CreateWellKnownSid(windows.WinBuiltinAdministratorsSid)
if err != nil {
return err
}
// Create an ACL with an ACE for user SID and an ACE for the
// administrators group SID, both of which are granted full
// access. No other ACEs are provided which restricts access
// from non-administrators
dacl, err := windows.ACLFromEntries(
[]windows.EXPLICIT_ACCESS{
{
AccessPermissions: windows.GENERIC_ALL,
AccessMode: windows.SET_ACCESS,
Inheritance: windows.SUB_CONTAINERS_AND_OBJECTS_INHERIT,
Trustee: windows.TRUSTEE{
MultipleTrusteeOperation: windows.NO_MULTIPLE_TRUSTEE,
TrusteeForm: windows.TRUSTEE_IS_SID,
TrusteeType: windows.TRUSTEE_IS_USER,
TrusteeValue: windows.TrusteeValueFromSID(userSid),
},
},
{
AccessPermissions: windows.GENERIC_ALL,
AccessMode: windows.SET_ACCESS,
Inheritance: windows.SUB_CONTAINERS_AND_OBJECTS_INHERIT,
Trustee: windows.TRUSTEE{
MultipleTrusteeOperation: windows.NO_MULTIPLE_TRUSTEE,
TrusteeForm: windows.TRUSTEE_IS_SID,
TrusteeType: windows.TRUSTEE_IS_WELL_KNOWN_GROUP,
TrusteeValue: windows.TrusteeValueFromSID(gsid),
},
},
}, nil,
)
if err != nil {
return err
}
// Apply the ACL to the directory
if err := windows.SetNamedSecurityInfo(path, windows.SE_FILE_OBJECT,
windows.OWNER_SECURITY_INFORMATION|
windows.GROUP_SECURITY_INFORMATION|
windows.DACL_SECURITY_INFORMATION|
windows.PROTECTED_DACL_SECURITY_INFORMATION,
userSid, groupSid, dacl, nil); err != nil {
return err
}
return nil
}
func (w *windowsPaths) load() error {
w.o.Do(func() {
w.SystemDrive = os.Getenv("SystemDrive")
if w.SystemDrive == "" {
w.loadErr = fmt.Errorf("cannot detect Windows SystemDrive path")
return
}
w.SystemRoot = strings.ReplaceAll(os.Getenv("SystemDrive"), "SystemDrive", w.SystemDrive)
w.ProgramData = os.Getenv("ProgramData")
if w.ProgramData == "" {
pdKey, err := registry.OpenKey(registry.LOCAL_MACHINE,
`SOFTWARE\Microsoft\Windows NT\CurrentVersion\ProfileList`, registry.QUERY_VALUE)
if err == nil {
if pdVal, _, err := pdKey.GetStringValue("ProgramData"); err == nil {
w.ProgramData = pdVal
}
}
}
if w.ProgramData == "" {
w.loadErr = fmt.Errorf("cannot detect Windows ProgramData path")
return
}
w.ProgramData = strings.ReplaceAll(w.ProgramData, "SystemDrive", w.SystemDrive)
w.ProgramFiles = os.Getenv("ProgramFiles")
if w.ProgramFiles == "" {
pdKey, err := registry.OpenKey(registry.LOCAL_MACHINE,
`SOFTWARE\Microsoft\Windows\CurrentVersion`, registry.QUERY_VALUE)
if err == nil {
if pdVal, _, err := pdKey.GetStringValue("ProgramFilesDir"); err == nil {
w.ProgramFiles = pdVal
}
}
}
if w.ProgramFiles == "" {
w.loadErr = fmt.Errorf("cannot detect Windows ProgramFiles path")
return
}
w.ProgramFiles = strings.ReplaceAll(w.ProgramFiles, "SystemDrive", w.SystemDrive)
})
return w.loadErr
}

View File

@@ -0,0 +1,203 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package winsvc
import (
"os"
"path/filepath"
"testing"
"unsafe"
"github.com/hashicorp/nomad/ci"
"github.com/shoenig/test/must"
"golang.org/x/sys/windows"
)
func TestCreateDirectory(t *testing.T) {
ci.Parallel(t)
testDir := t.TempDir()
t.Run("create", func(t *testing.T) {
// NOTE: parallel is not set here to force parent
// to wait for subtests to complete
t.Run("unrestricted", func(t *testing.T) {
ci.Parallel(t)
path := filepath.Join(testDir, t.Name())
err := NewWindowsPaths().CreateDirectory(path, false)
must.NoError(t, err)
dacl := getDirectoryDACL(t, path)
// When not applying restrictions on the new directory, all
// ACEs will be inherited from the parent
for i := range dacl.AceCount {
ace := &windows.ACCESS_ALLOWED_ACE{}
must.NoError(t, windows.GetAce(dacl, uint32(i), &ace), must.Sprint("failed to load ACE"))
must.Eq(t, windows.INHERITED_ACCESS_ENTRY, ace.Header.AceFlags&windows.INHERITED_ACCESS_ENTRY,
must.Sprint("ACE is not inherited"))
}
})
t.Run("restricted", func(t *testing.T) {
ci.Parallel(t)
path := filepath.Join(testDir, t.Name())
err := NewWindowsPaths().CreateDirectory(path, true)
must.NoError(t, err)
dacl := getDirectoryDACL(t, path)
matches := map[string]struct{}{}
// When restrictions are applied on the new directory, all
// ACEs will be directly applied.
for i := range dacl.AceCount {
ace := &windows.ACCESS_ALLOWED_ACE{}
must.NoError(t, windows.GetAce(dacl, uint32(i), &ace), must.Sprint("failed to load ACE"))
must.NotEq(t, windows.INHERITED_ACCESS_ENTRY, ace.Header.AceFlags&windows.INHERITED_ACCESS_ENTRY,
must.Sprint("ACE should not be inherited"))
if ace.Mask&windows.GENERIC_ALL == windows.GENERIC_ALL {
sid := (*windows.SID)(unsafe.Pointer(&ace.SidStart))
matches[sid.String()] = struct{}{}
}
}
// All privileges should be set for user and administrators groups
adminGroupSID, err := windows.CreateWellKnownSid(windows.WinBuiltinAdministratorsSid)
must.NoError(t, err, must.Sprint("failed to create well known administrators group SID"))
userSID, _, err := getUserGroupSIDs()
must.NoError(t, err, must.Sprint("failed to get user SID"))
must.NotNil(t, matches[userSID.String()], must.Sprint("missing user ACE with GENERIC_ALL"))
must.NotNil(t, matches[adminGroupSID.String()],
must.Sprint("missing administrators group ACE with GENERIC_ALL"))
must.Eq(t, 2, len(matches), must.Sprint("unexpected GENERIC_ALL ACEs found"))
})
t.Run("unrestricted already exists", func(t *testing.T) {
ci.Parallel(t)
path := filepath.Join(testDir, t.Name())
must.NoError(t, os.MkdirAll(path, 0o000))
err := NewWindowsPaths().CreateDirectory(path, false)
must.NoError(t, err)
dacl := getDirectoryDACL(t, path)
// No restrictions are applied, so check that all ACEs
// are inherited from parent
for i := range dacl.AceCount {
ace := &windows.ACCESS_ALLOWED_ACE{}
must.NoError(t, windows.GetAce(dacl, uint32(i), &ace), must.Sprint("failed to load ACE"))
must.Eq(t, windows.INHERITED_ACCESS_ENTRY, ace.Header.AceFlags&windows.INHERITED_ACCESS_ENTRY,
must.Sprint("ACE is not inherited"))
}
})
t.Run("restricted already exists", func(t *testing.T) {
ci.Parallel(t)
path := filepath.Join(testDir, t.Name())
must.NoError(t, os.MkdirAll(path, 0o000))
err := NewWindowsPaths().CreateDirectory(path, true)
must.NoError(t, err)
dacl := getDirectoryDACL(t, path)
// When the directory already exists, restrictions should not
// be applied so validate that all ACEs are inherited
for i := range dacl.AceCount {
ace := &windows.ACCESS_ALLOWED_ACE{}
must.NoError(t, windows.GetAce(dacl, uint32(i), &ace), must.Sprint("failed to load ACE"))
must.Eq(t, windows.INHERITED_ACCESS_ENTRY, ace.Header.AceFlags&windows.INHERITED_ACCESS_ENTRY,
must.Sprint("ACE is not inherited"))
}
})
})
}
func TestExpand(t *testing.T) {
t.Run("SystemDrive", func(t *testing.T) {
t.Run("default", func(t *testing.T) {
result, err := NewWindowsPaths().Expand(`{{.SystemDrive}}/testing`)
must.NoError(t, err)
must.StrNotContains(t, result, "{{.SystemDrive}}")
})
t.Run("custom environment variable", func(t *testing.T) {
t.Setenv("SystemDrive", `z:`)
result, err := NewWindowsPaths().Expand(`{{.SystemDrive}}\testing`)
must.NoError(t, err)
must.Eq(t, `z:\testing`, result)
})
t.Run("unset environment variable", func(t *testing.T) {
t.Setenv("SystemDrive", "")
_, err := NewWindowsPaths().Expand(`{{.SystemDrive}}\testing`)
must.ErrorContains(t, err, "cannot detect Windows SystemDrive path")
})
})
t.Run("ProgramData", func(t *testing.T) {
t.Run("default", func(t *testing.T) {
result, err := NewWindowsPaths().Expand(`{{.ProgramData}}/testing`)
must.NoError(t, err)
must.StrNotContains(t, result, "{{.ProgramData}}")
})
t.Run("custom environment variable", func(t *testing.T) {
t.Setenv("ProgramData", `z:`)
result, err := NewWindowsPaths().Expand(`{{.ProgramData}}\testing`)
must.NoError(t, err)
must.Eq(t, `z:\testing`, result)
})
t.Run("unset environment variable", func(t *testing.T) {
t.Setenv("ProgramData", "")
result, err := NewWindowsPaths().Expand(`{{.ProgramData}}\testing`)
must.NoError(t, err)
must.StrNotContains(t, result, "{{.ProgramData}}") // should be pulled from registry
})
})
t.Run("ProgramFiles", func(t *testing.T) {
t.Run("default", func(t *testing.T) {
result, err := NewWindowsPaths().Expand(`{{.ProgramFiles}}/testing`)
must.NoError(t, err)
must.StrNotContains(t, result, "{{.ProgramFiles}}")
})
t.Run("custom environment variable", func(t *testing.T) {
t.Setenv("ProgramFiles", `z:`)
result, err := NewWindowsPaths().Expand(`{{.ProgramFiles}}\testing`)
must.NoError(t, err)
must.Eq(t, `z:\testing`, result)
})
t.Run("unset environment variable", func(t *testing.T) {
t.Setenv("ProgramFiles", "")
result, err := NewWindowsPaths().Expand(`{{.ProgramFiles}}\testing`)
must.NoError(t, err)
must.StrNotContains(t, result, "{{.ProgramFiles}}") // should be pulled from registry
})
})
t.Run("missing key", func(t *testing.T) {
_, err := NewWindowsPaths().Expand(`{{.Unknown}}\testing`)
must.ErrorContains(t, err, "can't evaluate field")
})
}
func getDirectoryDACL(t *testing.T, path string) *windows.ACL {
t.Helper()
s, err := os.Stat(path)
must.NoError(t, err)
must.True(t, s.IsDir(), must.Sprint("expected path to be a directory"))
info, err := windows.GetNamedSecurityInfo(path,
windows.SE_FILE_OBJECT, windows.DACL_SECURITY_INFORMATION)
must.NoError(t, err, must.Sprint("failed to get path security information"))
dacl, _, err := info.DACL()
must.NoError(t, err, must.Sprint("failed to get path ACL"))
return dacl
}

View File

@@ -0,0 +1,11 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
//go:build !windows
package winsvc
// IsPrivilegedProcess checks if current process is a privileged windows process
func IsPrivilegedProcess() bool {
return false
}

View File

@@ -0,0 +1,11 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package winsvc
import "golang.org/x/sys/windows"
// IsPrivilegedProcess checks if current process is a privileged windows process
func IsPrivilegedProcess() bool {
return windows.GetCurrentProcessToken().IsElevated()
}

View File

@@ -3,6 +3,18 @@
package winsvc package winsvc
const (
WINDOWS_SERVICE_NAME = "nomad"
WINDOWS_SERVICE_DISPLAY_NAME = "HashiCorp Nomad"
WINDOWS_SERVICE_DESCRIPTION = "Workload scheduler and orchestrator - https://nomadproject.io"
WINDOWS_INSTALL_BIN_DIRECTORY = `{{.ProgramFiles}}\HashiCorp\nomad\bin`
WINDOWS_INSTALL_APPDATA_DIRECTORY = `{{.ProgramData}}\HashiCorp\nomad`
// Number of seconds to wait for a
// service to reach a desired state
WINDOWS_SERVICE_STATE_TIMEOUT = "1m"
)
var chanGraceExit = make(chan int) var chanGraceExit = make(chan int)
// ShutdownChannel returns a channel that sends a message that a shutdown // ShutdownChannel returns a channel that sends a message that a shutdown

View File

@@ -0,0 +1,75 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package winsvc
type ServiceStartType uint32
// extracted from https://pkg.go.dev/golang.org/x/sys@v0.35.0/windows/svc/mgr#StartManual
const (
StartManual ServiceStartType = 3
StartAutomatic ServiceStartType = 2
StartDisabled ServiceStartType = 4
)
type WindowsServiceConfiguration struct {
StartType ServiceStartType
DisplayName string
Description string
BinaryPathName string
}
type WindowsPaths interface {
// Expand expands the path defined by the template. Supports
// values for:
// - SystemDrive
// - SystemRoot
// - ProgramData
// - ProgramFiles
Expand(path string) (string, error)
// Creates a new directory if it does not exist. If directory
// is created and restrict_on_create is true, a restrictive
// ACL is applied.
CreateDirectory(path string, restrict_on_create bool) error
}
type WindowsService interface {
// Name returns the name of the service
Name() string
// Configure applies the configuration to the Windows service.
// NOTE: Full configuration applied so empty values will remove existing values.
Configure(config WindowsServiceConfiguration) error
// Start starts the Windows service and waits for the
// service to be running.
Start() error
// Stop requests the service to stop and waits for the
// service to stop.
Stop() error
// Close closes the connection to the Windows service.
Close() error
// Delete deletes the Windows service.
Delete() error
// IsRunning returns if the service is currently running.
IsRunning() (bool, error)
// IsStopped returns if the service is currently stopped.
IsStopped() (bool, error)
// EnableEventlog will add or update the Windows Eventlog
// configuration for the service. It will set supported
// events as info, warning, and error.
EnableEventlog() error
// DisableEventlog will remove the Windows Eventlog configuration
// for the service.
DisableEventlog() error
}
type WindowsServiceManager interface {
// IsServiceRegistered returns if the service is a registered Windows service.
IsServiceRegistered(name string) (bool, error)
// GetService opens and returns the named service.
GetService(name string) (WindowsService, error)
// CreateService creates a new Windows service.
CreateService(name, binaryPath string, config WindowsServiceConfiguration) (WindowsService, error)
// Close closes Windows service manager connection.
Close() error
}

View File

@@ -0,0 +1,15 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
//go:build !windows
package winsvc
import (
"errors"
)
// NewWindowsServiceManager returns an error
func NewWindowsServiceManager() (WindowsServiceManager, error) {
return nil, errors.New("Windows service manager is not supported on this platform")
}

View File

@@ -0,0 +1,256 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package winsvc
import (
"context"
"errors"
"fmt"
"io/fs"
"os"
"os/signal"
"reflect"
"slices"
"time"
"github.com/hashicorp/nomad/helper"
"golang.org/x/sys/windows/registry"
"golang.org/x/sys/windows/svc"
"golang.org/x/sys/windows/svc/eventlog"
"golang.org/x/sys/windows/svc/mgr"
)
// Base registry path for eventlog registrations
const EVENTLOG_REGISTRY_PATH = `SYSTEM\CurrentControlSet\Services\EventLog\Application`
// Registry value name for supported event types
const EVENTLOG_SUPPORTED_EVENTS_KEY = "TypesSupported"
// Event types registered as supported
const EVENTLOG_SUPPORTED_EVENTS uint32 = eventlog.Error | eventlog.Warning | eventlog.Info
// NewWindowsServiceManager creates a new instance of the wrapper
// to interact with the Windows service manager.
func NewWindowsServiceManager() (WindowsServiceManager, error) {
m, err := mgr.Connect()
if err != nil {
return nil, err
}
return &windowsServiceManager{manager: m}, nil
}
type windowsServiceManager struct {
manager *mgr.Mgr
}
func (m *windowsServiceManager) IsServiceRegistered(name string) (bool, error) {
list, err := m.manager.ListServices()
if err != nil {
return false, err
}
if slices.Contains(list, name) {
return true, nil
}
return false, nil
}
func (m *windowsServiceManager) GetService(name string) (WindowsService, error) {
service, err := m.manager.OpenService(name)
if err != nil {
return nil, err
}
return &windowsService{service: service}, nil
}
func (m *windowsServiceManager) CreateService(name, bin string, config WindowsServiceConfiguration) (WindowsService, error) {
wsvc, err := m.manager.CreateService(name, bin, mgr.Config{})
if err != nil {
return nil, err
}
service := &windowsService{service: wsvc}
// Only apply configuration if configuration is provided
if !reflect.ValueOf(config).IsZero() {
if err := service.Configure(config); err != nil {
return nil, err
}
}
return service, nil
}
func (m *windowsServiceManager) Close() error {
return m.manager.Disconnect()
}
type windowsService struct {
service *mgr.Service
}
func (s *windowsService) Name() string {
return s.service.Name
}
func (s *windowsService) Configure(config WindowsServiceConfiguration) error {
serviceCfg, err := s.service.Config()
if err != nil {
return err
}
serviceCfg.StartType = uint32(config.StartType)
serviceCfg.DisplayName = config.DisplayName
serviceCfg.Description = config.Description
serviceCfg.BinaryPathName = config.BinaryPathName
if err := s.service.UpdateConfig(serviceCfg); err != nil {
return err
}
return nil
}
func (s *windowsService) Start() error {
if running, _ := s.IsRunning(); running {
return nil
}
if err := s.service.Start(); err != nil {
return err
}
if err := waitFor(context.Background(), s.IsRunning); err != nil {
return err
}
return nil
}
func (s *windowsService) Stop() error {
if stopped, _ := s.IsStopped(); stopped {
return nil
}
if _, err := s.service.Control(svc.Stop); err != nil {
return err
}
if err := waitFor(context.Background(), s.IsStopped); err != nil {
return err
}
return nil
}
func (s *windowsService) Close() error {
return s.service.Close()
}
func (s *windowsService) Delete() error {
return s.service.Delete()
}
func (s *windowsService) IsRunning() (bool, error) {
return s.isService(svc.Running)
}
func (s *windowsService) IsStopped() (bool, error) {
return s.isService(svc.Stopped)
}
func (s *windowsService) EnableEventlog() error {
// Check if the service is already setup in the eventlog
key, err := registry.OpenKey(registry.LOCAL_MACHINE,
EVENTLOG_REGISTRY_PATH+`\`+s.Name(),
registry.ALL_ACCESS,
)
// If it could not be opened, assume error is caused
// due to nonexistence. If it was for some other reason
// the error will be encountered again when attempting to
// create.
if err != nil {
if err := eventlog.InstallAsEventCreate(s.Name(), EVENTLOG_SUPPORTED_EVENTS); err != nil {
return err
}
} else {
defer key.Close()
// Since the service is already registered, just
// ensure it is properly configured. Currently
// that just means the supported events.
val, _, err := key.GetIntegerValue(EVENTLOG_SUPPORTED_EVENTS_KEY)
if err != nil || uint32(val) != EVENTLOG_SUPPORTED_EVENTS {
if err := key.SetDWordValue(EVENTLOG_SUPPORTED_EVENTS_KEY, EVENTLOG_SUPPORTED_EVENTS); err != nil {
return err
}
}
}
return nil
}
func (s *windowsService) DisableEventlog() error {
// Check if the service is currently enabled in the eventlog
key, err := registry.OpenKey(registry.LOCAL_MACHINE,
EVENTLOG_REGISTRY_PATH+`\`+s.Name(),
registry.READ,
)
if errors.Is(err, fs.ErrNotExist) {
return nil
}
defer key.Close()
return eventlog.Remove(s.Name())
}
func (s *windowsService) isService(state svc.State) (bool, error) {
status, err := s.service.Query()
if err != nil {
return false, err
}
return status.State == state, nil
}
func waitFor(ctx context.Context, condition func() (bool, error)) error {
d, err := time.ParseDuration(WINDOWS_SERVICE_STATE_TIMEOUT)
if err != nil {
return err
}
// Setup a deadline
ctx, cancel := context.WithDeadline(ctx, time.Now().Add(d))
defer cancel()
// Watch for any interrupts
ctx, stop := signal.NotifyContext(ctx, os.Interrupt)
defer stop()
pauseDur := time.Millisecond * 250
t, timerStop := helper.NewSafeTimer(pauseDur)
defer timerStop()
for {
t.Reset(pauseDur)
complete, err := condition()
if err != nil {
return err
}
if complete {
return nil
}
select {
case <-ctx.Done():
return fmt.Errorf("timeout exceeded waiting for process")
case <-t.C:
}
}
}

View File

@@ -0,0 +1,598 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package winsvc
import (
"context"
"io/fs"
"testing"
"github.com/hashicorp/go-uuid"
"github.com/hashicorp/nomad/ci"
"github.com/shoenig/test/must"
"golang.org/x/sys/windows/registry"
"golang.org/x/sys/windows/svc"
"golang.org/x/sys/windows/svc/mgr"
)
func TestWindowsServiceManager(t *testing.T) {
ci.Parallel(t)
t.Run("IsServiceRegistered", func(t *testing.T) {
ci.Parallel(t)
t.Run("service does not exist", func(t *testing.T) {
ci.Parallel(t)
_, manager := makeManagers(t)
result, err := manager.IsServiceRegistered("fake-service-name")
must.NoError(t, err, must.Sprint("check should not error"))
must.False(t, result, must.Sprint("service should not exist"))
})
t.Run("service does exist", func(t *testing.T) {
ci.Parallel(t)
m, manager := makeManagers(t)
serviceName := generateStubService(t, m)
result, err := manager.IsServiceRegistered(serviceName)
must.NoError(t, err, must.Sprint("check should not error"))
must.True(t, result, must.Sprint("service should exist"))
})
})
t.Run("GetService", func(t *testing.T) {
ci.Parallel(t)
t.Run("service does not exist", func(t *testing.T) {
ci.Parallel(t)
_, manager := makeManagers(t)
_, err := manager.GetService("fake-service-name")
must.ErrorContains(t, err, "specified service does not exist",
must.Sprint("error should be generated when service does not exist"))
})
t.Run("service does exist", func(t *testing.T) {
ci.Parallel(t)
m, manager := makeManagers(t)
serviceName := generateStubService(t, m)
srv, err := manager.GetService(serviceName)
must.NoError(t, err)
defer srv.Close()
must.Eq(t, serviceName, srv.Name(), must.Sprint("service name does not match"))
})
})
t.Run("CreateService", func(t *testing.T) {
ci.Parallel(t)
t.Run("service does not exist", func(t *testing.T) {
ci.Parallel(t)
serviceName := generateServiceName()
m, manager := makeManagers(t)
srv, err := manager.CreateService(serviceName, `c:\stub`, WindowsServiceConfiguration{})
must.NoError(t, err)
defer srv.Close()
defer deleteStubService(t, m, serviceName)
must.Eq(t, serviceName, srv.Name(), must.Sprint("new service name is incorrect"))
})
t.Run("service does exist", func(t *testing.T) {
ci.Parallel(t)
m, manager := makeManagers(t)
serviceName := generateStubService(t, m)
_, err := manager.CreateService(serviceName, `c:\stub`, WindowsServiceConfiguration{})
must.ErrorContains(t, err, "service already exists", must.Sprint("service creation should fail"))
})
t.Run("with configuration", func(t *testing.T) {
ci.Parallel(t)
m, manager := makeManagers(t)
serviceName := generateServiceName()
srv, err := manager.CreateService(serviceName, `c:\stub`,
WindowsServiceConfiguration{DisplayName: "testing service", StartType: StartDisabled})
must.NoError(t, err, must.Sprint("service should be created"))
defer srv.Close()
defer deleteStubService(t, m, serviceName)
directSrv, err := m.OpenService(serviceName)
must.NoError(t, err, must.Sprint("direct service connection should succeed"))
defer directSrv.Close()
config, err := directSrv.Config()
must.NoError(t, err, must.Sprint("configuration should be available from service"))
must.Eq(t, "testing service", config.DisplayName, must.Sprint("new service name does not match"))
})
})
}
// This is a simple service available in Windows. It will
// be used to locate the executable so a test service can
// be created using it that will allow proper start/stop
// testing.
const TEST_WINDOWS_SERVICE = "SNMPTrap"
func TestWindowsService(t *testing.T) {
ci.Parallel(t)
mg, _ := makeManagers(t)
snmpSvc, err := mg.OpenService(TEST_WINDOWS_SERVICE)
must.NoError(t, err)
defer snmpSvc.Close()
snmpConfig, err := snmpSvc.Config()
must.NoError(t, err)
binPath := snmpConfig.BinaryPathName
t.Run("Name", func(t *testing.T) {
ci.Parallel(t)
m, manager := makeManagers(t)
serviceName := generateStubService(t, m)
srv, err := manager.GetService(serviceName)
must.NoError(t, err)
defer srv.Close()
must.Eq(t, serviceName, srv.Name(), must.Sprint("service name does not match"))
})
t.Run("Configure", func(t *testing.T) {
ci.Parallel(t)
t.Run("valid configuration", func(t *testing.T) {
ci.Parallel(t)
m, manager := makeManagers(t)
serviceName := generateStubService(t, m)
srv, err := manager.GetService(serviceName)
must.NoError(t, err, must.Sprint("service should be available"))
err = srv.Configure(WindowsServiceConfiguration{
StartType: StartDisabled,
DisplayName: "testing display name",
BinaryPathName: `c:\stub -with -arguments`,
})
must.NoError(t, err, must.Sprint("valid configuration should not error"))
directSrv, err := m.OpenService(serviceName)
must.NoError(t, err, must.Sprint("direct service should be available"))
defer directSrv.Close()
config, err := directSrv.Config()
must.NoError(t, err, must.Sprint("direct service config should be available"))
must.Eq(t, "testing display name", config.DisplayName, must.Sprint("display name does not match"))
must.Eq(t, `c:\stub -with -arguments`, config.BinaryPathName, must.Sprint("binary path name does not match"))
})
t.Run("invalid configuration", func(t *testing.T) {
ci.Parallel(t)
m, manager := makeManagers(t)
serviceName := generateStubService(t, m)
srv, err := manager.GetService(serviceName)
must.NoError(t, err, must.Sprint("service should be available"))
err = srv.Configure(WindowsServiceConfiguration{
DisplayName: "testing display name",
BinaryPathName: `c:\stub -with -arguments`,
})
must.ErrorContains(t, err, "parameter is incorrect", must.Sprint("invalid configuration should error"))
})
})
t.Run("Start", func(t *testing.T) {
ci.Parallel(t)
t.Run("when stopped", func(t *testing.T) {
ci.Parallel(t)
m, manager := makeManagers(t)
runnableSvc := runnableService(t, manager, binPath)
directSrv, err := m.OpenService(runnableSvc.Name())
must.NoError(t, err, must.Sprint("direct service should be available"))
defer directSrv.Close()
status, err := directSrv.Query()
must.NoError(t, err, must.Sprint("direct service status should be available"))
if status.State != svc.Stopped {
_, err := directSrv.Control(svc.Stop)
must.NoError(t, err, must.Sprint("direct stop should not fail"))
err = waitFor(context.Background(), func() (bool, error) {
status, err := directSrv.Query()
must.NoError(t, err, must.Sprint("direct service should be queryable"))
return status.State == svc.Stopped, nil
})
must.NoError(t, err, must.Sprint("service must be stopped"))
}
must.NoError(t, runnableSvc.Start(), must.Sprint("service should start without error"))
status, err = directSrv.Query()
must.NoError(t, err, must.Sprint("direct service status should be available"))
must.Eq(t, status.State, svc.Running, must.Sprint("service should be running"))
})
t.Run("when running", func(t *testing.T) {
ci.Parallel(t)
m, manager := makeManagers(t)
runnableSvc := runnableService(t, manager, binPath)
directSrv, err := m.OpenService(runnableSvc.Name())
must.NoError(t, err, must.Sprint("direct service should be available"))
defer directSrv.Close()
status, err := directSrv.Query()
must.NoError(t, err, must.Sprint("direct service status should be available"))
if status.State != svc.Running {
must.NoError(t, directSrv.Start(), must.Sprint("direct start should not fail"))
err := waitFor(context.Background(), func() (bool, error) {
status, err := directSrv.Query()
must.NoError(t, err, must.Sprint("direct service should be queryable"))
return status.State == svc.Running, nil
})
must.NoError(t, err, must.Sprint("service must be running"))
}
must.NoError(t, runnableSvc.Start(), must.Sprint("service should start without error"))
status, err = directSrv.Query()
must.NoError(t, err, must.Sprint("direct service status should be available"))
must.Eq(t, status.State, svc.Running, must.Sprint("service should be running"))
})
})
t.Run("Stop", func(t *testing.T) {
ci.Parallel(t)
t.Run("when stopped", func(t *testing.T) {
ci.Parallel(t)
m, manager := makeManagers(t)
runnableSvc := runnableService(t, manager, binPath)
directSrv, err := m.OpenService(runnableSvc.Name())
must.NoError(t, err, must.Sprint("direct service should be available"))
defer directSrv.Close()
status, err := directSrv.Query()
must.NoError(t, err, must.Sprint("direct service status should be available"))
if status.State != svc.Stopped {
_, err := directSrv.Control(svc.Stop)
must.NoError(t, err, must.Sprint("direct stop should not fail"))
err = waitFor(context.Background(), func() (bool, error) {
status, err := directSrv.Query()
must.NoError(t, err, must.Sprint("direct service should be queryable"))
return status.State == svc.Stopped, nil
})
must.NoError(t, err, must.Sprint("service must be stopped"))
}
must.NoError(t, runnableSvc.Stop(), must.Sprint("service should stop without error"))
status, err = directSrv.Query()
must.NoError(t, err, must.Sprint("direct service status should be available"))
must.Eq(t, status.State, svc.Stopped, must.Sprint("service should be stopped"))
})
t.Run("when running", func(t *testing.T) {
ci.Parallel(t)
m, manager := makeManagers(t)
runnableSvc := runnableService(t, manager, binPath)
directSrv, err := m.OpenService(runnableSvc.Name())
must.NoError(t, err, must.Sprint("direct service should be available"))
defer directSrv.Close()
status, err := directSrv.Query()
must.NoError(t, err, must.Sprint("direct service status should be available"))
if status.State != svc.Running {
must.NoError(t, directSrv.Start(), must.Sprint("direct start should not fail"))
err := waitFor(context.Background(), func() (bool, error) {
status, err := directSrv.Query()
must.NoError(t, err, must.Sprint("direct service should be queryable"))
return status.State == svc.Running, nil
})
must.NoError(t, err, must.Sprint("service must be running"))
}
must.NoError(t, runnableSvc.Stop(), must.Sprint("service should stop without error"))
status, err = directSrv.Query()
must.NoError(t, err, must.Sprint("direct service status should be available"))
must.Eq(t, status.State, svc.Stopped, must.Sprint("service should be stopped"))
})
})
t.Run("Delete", func(t *testing.T) {
ci.Parallel(t)
t.Run("when service exists", func(t *testing.T) {
ci.Parallel(t)
m, manager := makeManagers(t)
serviceName := generateStubService(t, m)
srv, err := manager.GetService(serviceName)
must.NoError(t, err, must.Sprint("service should be avaialble"))
defer srv.Close()
must.NoError(t, srv.Delete(), must.Sprint("service should be deleted"))
})
t.Run("when service does not exist", func(t *testing.T) {
ci.Parallel(t)
m, manager := makeManagers(t)
serviceName := generateStubService(t, m)
srv, err := manager.GetService(serviceName)
must.NoError(t, err, must.Sprint("service should be avaialble"))
defer srv.Close()
// Delete the service directly
directSrv, err := m.OpenService(serviceName)
must.NoError(t, err, must.Sprint("direct service should be available"))
defer directSrv.Close()
must.NoError(t, directSrv.Delete(), must.Sprint("service should be deleted"))
must.ErrorContains(t, srv.Delete(), "marked for deletion",
must.Sprint("service should have already been deleted"))
})
})
t.Run("IsRunning", func(t *testing.T) {
ci.Parallel(t)
t.Run("when service is not running", func(t *testing.T) {
ci.Parallel(t)
m, manager := makeManagers(t)
runnableSvc := runnableService(t, manager, binPath)
directSrv, err := m.OpenService(runnableSvc.Name())
must.NoError(t, err, must.Sprint("direct service should be available"))
defer directSrv.Close()
status, err := directSrv.Query()
must.NoError(t, err, must.Sprint("direct service status should be available"))
if status.State != svc.Stopped {
_, err := directSrv.Control(svc.Stop)
must.NoError(t, err, must.Sprint("direct stop should not fail"))
err = waitFor(context.Background(), func() (bool, error) {
status, err := directSrv.Query()
must.NoError(t, err, must.Sprint("direct service should be queryable"))
return status.State == svc.Stopped, nil
})
must.NoError(t, err, must.Sprint("service must be stopped"))
}
srv, err := manager.GetService(directSrv.Name)
must.NoError(t, err, must.Sprint("service should be available"))
defer srv.Close()
result, err := srv.IsRunning()
must.NoError(t, err, must.Sprint("running check should not error"))
must.False(t, result, must.Sprint("should not show service as running"))
})
t.Run("when service is running", func(t *testing.T) {
ci.Parallel(t)
m, manager := makeManagers(t)
runnableSvc := runnableService(t, manager, binPath)
directSrv, err := m.OpenService(runnableSvc.Name())
must.NoError(t, err, must.Sprint("direct service should be available"))
defer directSrv.Close()
status, err := directSrv.Query()
must.NoError(t, err, must.Sprint("direct service status should be available"))
if status.State != svc.Running {
must.NoError(t, directSrv.Start(), must.Sprint("direct start should not fail"))
err := waitFor(context.Background(), func() (bool, error) {
status, err := directSrv.Query()
must.NoError(t, err, must.Sprint("direct service should be queryable"))
return status.State == svc.Running, nil
})
must.NoError(t, err, must.Sprint("service must be running"))
}
srv, err := manager.GetService(directSrv.Name)
must.NoError(t, err, must.Sprint("service should be available"))
defer srv.Close()
result, err := srv.IsRunning()
must.NoError(t, err, must.Sprint("running check should not error"))
must.True(t, result, must.Sprint("should show service as running"))
})
})
t.Run("IsStopped", func(t *testing.T) {
ci.Parallel(t)
t.Run("when service is not running", func(t *testing.T) {
ci.Parallel(t)
m, manager := makeManagers(t)
runnableSvc := runnableService(t, manager, binPath)
directSrv, err := m.OpenService(runnableSvc.Name())
must.NoError(t, err, must.Sprint("direct service should be available"))
defer directSrv.Close()
status, err := directSrv.Query()
must.NoError(t, err, must.Sprint("direct service status should be available"))
if status.State != svc.Stopped {
_, err := directSrv.Control(svc.Stop)
must.NoError(t, err, must.Sprint("direct stop should not fail"))
err = waitFor(context.Background(), func() (bool, error) {
status, err := directSrv.Query()
must.NoError(t, err, must.Sprint("direct service should be queryable"))
return status.State == svc.Stopped, nil
})
must.NoError(t, err, must.Sprint("service must be stopped"))
}
srv, err := manager.GetService(directSrv.Name)
must.NoError(t, err, must.Sprint("service should be available"))
defer srv.Close()
result, err := srv.IsStopped()
must.NoError(t, err, must.Sprint("running check should not error"))
must.True(t, result, must.Sprint("should show service as stopped"))
})
t.Run("when service is running", func(t *testing.T) {
ci.Parallel(t)
m, manager := makeManagers(t)
runnableSvc := runnableService(t, manager, binPath)
directSrv, err := m.OpenService(runnableSvc.Name())
must.NoError(t, err, must.Sprint("direct service should be available"))
defer directSrv.Close()
status, err := directSrv.Query()
must.NoError(t, err, must.Sprint("direct service status should be available"))
if status.State != svc.Running {
must.NoError(t, directSrv.Start(), must.Sprint("direct start should not fail"))
err := waitFor(context.Background(), func() (bool, error) {
status, err := directSrv.Query()
must.NoError(t, err, must.Sprint("direct service should be queryable"))
return status.State == svc.Running, nil
})
must.NoError(t, err, must.Sprint("service must be running"))
}
srv, err := manager.GetService(directSrv.Name)
must.NoError(t, err, must.Sprint("service should be available"))
defer srv.Close()
result, err := srv.IsStopped()
must.NoError(t, err, must.Sprint("running check should not error"))
must.False(t, result, must.Sprint("should not show service as stopped"))
})
})
t.Run("EnableEventLog", func(t *testing.T) {
ci.Parallel(t)
t.Run("when service is not registered", func(t *testing.T) {
ci.Parallel(t)
m, manager := makeManagers(t)
serviceName := generateStubService(t, m)
srv, err := manager.GetService(serviceName)
must.NoError(t, err, must.Sprint("service should be available"))
defer srv.Close()
must.NoError(t, srv.EnableEventlog(), must.Sprint("could not enable eventlog"))
key, err := registry.OpenKey(registry.LOCAL_MACHINE,
EVENTLOG_REGISTRY_PATH+`\`+serviceName,
registry.READ,
)
must.NoError(t, err, must.Sprint("registry key should be available"))
defer key.Close()
val, _, err := key.GetIntegerValue(EVENTLOG_SUPPORTED_EVENTS_KEY)
must.NoError(t, err, must.Sprint("registry key value should be available"))
must.Eq(t, EVENTLOG_SUPPORTED_EVENTS, uint32(val), must.Sprint("registry value should match"))
})
t.Run("when service is already registered", func(t *testing.T) {
ci.Parallel(t)
m, manager := makeManagers(t)
serviceName := generateStubService(t, m)
srv, err := manager.GetService(serviceName)
must.NoError(t, err, must.Sprint("service should be available"))
defer srv.Close()
must.NoError(t, srv.EnableEventlog(), must.Sprint("could not enable eventlog"))
// Modify value in registry
key, err := registry.OpenKey(registry.LOCAL_MACHINE,
EVENTLOG_REGISTRY_PATH+`\`+serviceName,
registry.ALL_ACCESS,
)
err = key.SetDWordValue(EVENTLOG_SUPPORTED_EVENTS_KEY, 1)
must.NoError(t, err, must.Sprint("could not modify registry value"))
// Now enable and verify value is correct
must.NoError(t, srv.EnableEventlog(), must.Sprint("failed to enable eventlog"))
val, _, err := key.GetIntegerValue(EVENTLOG_SUPPORTED_EVENTS_KEY)
must.NoError(t, err, must.Sprint("registry value should be available"))
must.Eq(t, EVENTLOG_SUPPORTED_EVENTS, uint32(val), must.Sprint("registry value should match"))
})
})
t.Run("DisableEventLog", func(t *testing.T) {
ci.Parallel(t)
t.Run("when service is not registered", func(t *testing.T) {
ci.Parallel(t)
m, manager := makeManagers(t)
serviceName := generateStubService(t, m)
srv, err := manager.GetService(serviceName)
must.NoError(t, err, must.Sprint("service should be available"))
defer srv.Close()
must.NoError(t, srv.DisableEventlog(), must.Sprint("eventlog disable should not error"))
})
t.Run("when service is registered", func(t *testing.T) {
ci.Parallel(t)
m, manager := makeManagers(t)
serviceName := generateStubService(t, m)
srv, err := manager.GetService(serviceName)
must.NoError(t, err, must.Sprint("service should be available"))
defer srv.Close()
must.NoError(t, srv.EnableEventlog(), must.Sprint("eventlog enable should not error"))
must.NoError(t, srv.DisableEventlog(), must.Sprint("eventlog disable should not error"))
_, err = registry.OpenKey(registry.LOCAL_MACHINE,
EVENTLOG_REGISTRY_PATH+`\`+serviceName,
registry.READ,
)
must.ErrorIs(t, err, fs.ErrNotExist, must.Sprint("registry key should no longer exist"))
})
})
}
func generateServiceName() string {
id, err := uuid.GenerateUUID()
if err != nil {
panic(err)
}
return id[:5]
}
func generateStubService(t *testing.T, m *mgr.Mgr) string {
t.Helper()
id := generateServiceName()
_, err := m.CreateService(id, `c:\stub`, mgr.Config{})
must.NoError(t, err, must.Sprint("failed to generate stub service"))
t.Cleanup(func() { deleteStubService(t, m, id) })
return id
}
func deleteStubService(t *testing.T, m *mgr.Mgr, svcId string) {
t.Helper()
srvc, err := m.OpenService(svcId)
if err != nil {
// If the service doesn't exist, then deletion is done so not
// an error. Otherwise, force an error.
must.ErrorContains(t, err, "service does not exist", must.Sprint("failed to open service"))
return
}
status, err := srvc.Query()
must.NoError(t, err, must.Sprint("failed to query service"))
if status.State != svc.Stopped {
status, err = srvc.Control(svc.Stop)
must.NoError(t, err, must.Sprint("failed to stop service"))
err := waitFor(context.Background(), func() (bool, error) {
status, err := srvc.Query()
must.NoError(t, err, must.Sprint("failed to query service"))
return status.State == svc.Stopped, nil
})
must.NoError(t, err, must.Sprintf("could not stop service for deletion - %s", svcId))
}
if err := srvc.Delete(); err != nil {
must.ErrorContains(t, err, "service has been marked for deletion", must.Sprint("failed to delete service"))
}
}
func makeManagers(t *testing.T) (*mgr.Mgr, WindowsServiceManager) {
t.Helper()
winM, err := NewWindowsServiceManager()
must.NoError(t, err, must.Sprint("failed to create service manager"))
m, err := mgr.Connect()
must.NoError(t, err, must.Sprint("failed to connect to windows service manager"))
t.Cleanup(func() {
winM.Close()
m.Disconnect()
})
return m, winM
}
func runnableService(t *testing.T, m WindowsServiceManager, binPath string) WindowsService {
t.Helper()
runnableSvc, err := m.CreateService(generateServiceName(), binPath,
WindowsServiceConfiguration{StartType: StartManual, BinaryPathName: binPath})
must.NoError(t, err, must.Sprint("failed to create runnable service"))
t.Cleanup(func() { runnableSvc.Close() })
return runnableSvc
}