mirror of
https://github.com/kemko/nomad.git
synced 2026-01-01 07:55:42 +03:00
[winsvc] Add interfaces for Windows services and service manager
Provides interfaces to the Windows service manager and Windows services. These interfaces support creating new Windows services, deleting Windows services, configuring Windows services, and registering/deregistering services with Windows Eventlog. A path helper is included to support expansion of paths using a subset of known folder IDs. A privileged helper is included to check that the process is currently being executed with elevated privileges, which are required for managing Windows services and modifying the registry.
This commit is contained in:
@@ -13,6 +13,7 @@ project {
|
||||
"ui/node_modules",
|
||||
"pnpm-workspace.yaml",
|
||||
"pnpm-lock.yaml",
|
||||
"helper/winsvc/strings_*.go",
|
||||
|
||||
// Enterprise files do not fall under the open source licensing. CE-ENT
|
||||
// merge conflicts might happen here, please be sure to put new CE
|
||||
|
||||
22
helper/winsvc/path_nonwindows.go
Normal file
22
helper/winsvc/path_nonwindows.go
Normal file
@@ -0,0 +1,22 @@
|
||||
// Copyright (c) HashiCorp, Inc.
|
||||
// SPDX-License-Identifier: BUSL-1.1
|
||||
|
||||
//go:build !windows
|
||||
|
||||
package winsvc
|
||||
|
||||
import "errors"
|
||||
|
||||
func NewWindowsPaths() WindowsPaths {
|
||||
return &windowsPaths{}
|
||||
}
|
||||
|
||||
type windowsPaths struct{}
|
||||
|
||||
func (w *windowsPaths) Expand(string) (string, error) {
|
||||
return "", errors.New("Windows path expansion not supported on this platform")
|
||||
}
|
||||
|
||||
func (w *windowsPaths) CreateDirectory(string, bool) error {
|
||||
return errors.New("Windows directory creation not supported on this platform")
|
||||
}
|
||||
206
helper/winsvc/path_windows.go
Normal file
206
helper/winsvc/path_windows.go
Normal file
@@ -0,0 +1,206 @@
|
||||
// Copyright (c) HashiCorp, Inc.
|
||||
// SPDX-License-Identifier: BUSL-1.1
|
||||
|
||||
package winsvc
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"text/template"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
"golang.org/x/sys/windows/registry"
|
||||
)
|
||||
|
||||
func NewWindowsPaths() WindowsPaths {
|
||||
return &windowsPaths{}
|
||||
}
|
||||
|
||||
type windowsPaths struct {
|
||||
SystemRoot string
|
||||
SystemDrive string
|
||||
ProgramData string
|
||||
ProgramFiles string
|
||||
loadErr error
|
||||
o sync.Once
|
||||
}
|
||||
|
||||
func (w *windowsPaths) Expand(path string) (string, error) {
|
||||
if err := w.load(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
tmpl := template.New("expansion").Option("missingkey=error")
|
||||
tmpl, err := tmpl.Parse(path)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
result := new(bytes.Buffer)
|
||||
if err := tmpl.Execute(result, w); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return result.String(), nil
|
||||
}
|
||||
|
||||
func (w *windowsPaths) CreateDirectory(path string, restrict_on_create bool) error {
|
||||
s, err := os.Stat(path)
|
||||
|
||||
if err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
return err
|
||||
}
|
||||
|
||||
if err == nil {
|
||||
// Directory exists so nothing to do
|
||||
if s.IsDir() {
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("path exists and is not directory - %s", path)
|
||||
}
|
||||
|
||||
// NOTE: mode ignored on Windows. If directory should
|
||||
// be restricted, an ACL will be applied below.
|
||||
if err := os.MkdirAll(path, 0o000); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Since the directory was just created, apply access
|
||||
// restrictions if requested
|
||||
if restrict_on_create {
|
||||
if err := setDirectoryPermissions(path); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func getUserGroupSIDs() (usid *windows.SID, gsid *windows.SID, err error) {
|
||||
// NOTE: this token is a pseudo-token and does not
|
||||
// need to be closed
|
||||
token := windows.GetCurrentProcessToken()
|
||||
|
||||
userToken, err := token.GetTokenUser()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
usid = userToken.User.Sid
|
||||
|
||||
userGroup, err := token.GetTokenPrimaryGroup()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
gsid = userGroup.PrimaryGroup
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func setDirectoryPermissions(path string) error {
|
||||
// Grab the user and group SID for who is running the process
|
||||
userSid, groupSid, err := getUserGroupSIDs()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Generate a SID for the administators group
|
||||
gsid, err := windows.CreateWellKnownSid(windows.WinBuiltinAdministratorsSid)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Create an ACL with an ACE for user SID and an ACE for the
|
||||
// administrators group SID, both of which are granted full
|
||||
// access. No other ACEs are provided which restricts access
|
||||
// from non-administrators
|
||||
dacl, err := windows.ACLFromEntries(
|
||||
[]windows.EXPLICIT_ACCESS{
|
||||
{
|
||||
AccessPermissions: windows.GENERIC_ALL,
|
||||
AccessMode: windows.SET_ACCESS,
|
||||
Inheritance: windows.SUB_CONTAINERS_AND_OBJECTS_INHERIT,
|
||||
Trustee: windows.TRUSTEE{
|
||||
MultipleTrusteeOperation: windows.NO_MULTIPLE_TRUSTEE,
|
||||
TrusteeForm: windows.TRUSTEE_IS_SID,
|
||||
TrusteeType: windows.TRUSTEE_IS_USER,
|
||||
TrusteeValue: windows.TrusteeValueFromSID(userSid),
|
||||
},
|
||||
},
|
||||
{
|
||||
AccessPermissions: windows.GENERIC_ALL,
|
||||
AccessMode: windows.SET_ACCESS,
|
||||
Inheritance: windows.SUB_CONTAINERS_AND_OBJECTS_INHERIT,
|
||||
Trustee: windows.TRUSTEE{
|
||||
MultipleTrusteeOperation: windows.NO_MULTIPLE_TRUSTEE,
|
||||
TrusteeForm: windows.TRUSTEE_IS_SID,
|
||||
TrusteeType: windows.TRUSTEE_IS_WELL_KNOWN_GROUP,
|
||||
TrusteeValue: windows.TrusteeValueFromSID(gsid),
|
||||
},
|
||||
},
|
||||
}, nil,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Apply the ACL to the directory
|
||||
if err := windows.SetNamedSecurityInfo(path, windows.SE_FILE_OBJECT,
|
||||
windows.OWNER_SECURITY_INFORMATION|
|
||||
windows.GROUP_SECURITY_INFORMATION|
|
||||
windows.DACL_SECURITY_INFORMATION|
|
||||
windows.PROTECTED_DACL_SECURITY_INFORMATION,
|
||||
userSid, groupSid, dacl, nil); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *windowsPaths) load() error {
|
||||
w.o.Do(func() {
|
||||
w.SystemDrive = os.Getenv("SystemDrive")
|
||||
if w.SystemDrive == "" {
|
||||
w.loadErr = fmt.Errorf("cannot detect Windows SystemDrive path")
|
||||
return
|
||||
}
|
||||
w.SystemRoot = strings.ReplaceAll(os.Getenv("SystemDrive"), "SystemDrive", w.SystemDrive)
|
||||
|
||||
w.ProgramData = os.Getenv("ProgramData")
|
||||
if w.ProgramData == "" {
|
||||
pdKey, err := registry.OpenKey(registry.LOCAL_MACHINE,
|
||||
`SOFTWARE\Microsoft\Windows NT\CurrentVersion\ProfileList`, registry.QUERY_VALUE)
|
||||
if err == nil {
|
||||
if pdVal, _, err := pdKey.GetStringValue("ProgramData"); err == nil {
|
||||
w.ProgramData = pdVal
|
||||
}
|
||||
}
|
||||
}
|
||||
if w.ProgramData == "" {
|
||||
w.loadErr = fmt.Errorf("cannot detect Windows ProgramData path")
|
||||
return
|
||||
}
|
||||
w.ProgramData = strings.ReplaceAll(w.ProgramData, "SystemDrive", w.SystemDrive)
|
||||
|
||||
w.ProgramFiles = os.Getenv("ProgramFiles")
|
||||
if w.ProgramFiles == "" {
|
||||
pdKey, err := registry.OpenKey(registry.LOCAL_MACHINE,
|
||||
`SOFTWARE\Microsoft\Windows\CurrentVersion`, registry.QUERY_VALUE)
|
||||
if err == nil {
|
||||
if pdVal, _, err := pdKey.GetStringValue("ProgramFilesDir"); err == nil {
|
||||
w.ProgramFiles = pdVal
|
||||
}
|
||||
}
|
||||
}
|
||||
if w.ProgramFiles == "" {
|
||||
w.loadErr = fmt.Errorf("cannot detect Windows ProgramFiles path")
|
||||
return
|
||||
}
|
||||
w.ProgramFiles = strings.ReplaceAll(w.ProgramFiles, "SystemDrive", w.SystemDrive)
|
||||
})
|
||||
|
||||
return w.loadErr
|
||||
}
|
||||
203
helper/winsvc/path_windows_test.go
Normal file
203
helper/winsvc/path_windows_test.go
Normal file
@@ -0,0 +1,203 @@
|
||||
// Copyright (c) HashiCorp, Inc.
|
||||
// SPDX-License-Identifier: BUSL-1.1
|
||||
|
||||
package winsvc
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"unsafe"
|
||||
|
||||
"github.com/hashicorp/nomad/ci"
|
||||
"github.com/shoenig/test/must"
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
func TestCreateDirectory(t *testing.T) {
|
||||
ci.Parallel(t)
|
||||
testDir := t.TempDir()
|
||||
|
||||
t.Run("create", func(t *testing.T) {
|
||||
// NOTE: parallel is not set here to force parent
|
||||
// to wait for subtests to complete
|
||||
t.Run("unrestricted", func(t *testing.T) {
|
||||
ci.Parallel(t)
|
||||
path := filepath.Join(testDir, t.Name())
|
||||
|
||||
err := NewWindowsPaths().CreateDirectory(path, false)
|
||||
must.NoError(t, err)
|
||||
|
||||
dacl := getDirectoryDACL(t, path)
|
||||
|
||||
// When not applying restrictions on the new directory, all
|
||||
// ACEs will be inherited from the parent
|
||||
for i := range dacl.AceCount {
|
||||
ace := &windows.ACCESS_ALLOWED_ACE{}
|
||||
must.NoError(t, windows.GetAce(dacl, uint32(i), &ace), must.Sprint("failed to load ACE"))
|
||||
must.Eq(t, windows.INHERITED_ACCESS_ENTRY, ace.Header.AceFlags&windows.INHERITED_ACCESS_ENTRY,
|
||||
must.Sprint("ACE is not inherited"))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("restricted", func(t *testing.T) {
|
||||
ci.Parallel(t)
|
||||
path := filepath.Join(testDir, t.Name())
|
||||
|
||||
err := NewWindowsPaths().CreateDirectory(path, true)
|
||||
must.NoError(t, err)
|
||||
|
||||
dacl := getDirectoryDACL(t, path)
|
||||
matches := map[string]struct{}{}
|
||||
|
||||
// When restrictions are applied on the new directory, all
|
||||
// ACEs will be directly applied.
|
||||
for i := range dacl.AceCount {
|
||||
ace := &windows.ACCESS_ALLOWED_ACE{}
|
||||
must.NoError(t, windows.GetAce(dacl, uint32(i), &ace), must.Sprint("failed to load ACE"))
|
||||
must.NotEq(t, windows.INHERITED_ACCESS_ENTRY, ace.Header.AceFlags&windows.INHERITED_ACCESS_ENTRY,
|
||||
must.Sprint("ACE should not be inherited"))
|
||||
|
||||
if ace.Mask&windows.GENERIC_ALL == windows.GENERIC_ALL {
|
||||
sid := (*windows.SID)(unsafe.Pointer(&ace.SidStart))
|
||||
matches[sid.String()] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
// All privileges should be set for user and administrators groups
|
||||
adminGroupSID, err := windows.CreateWellKnownSid(windows.WinBuiltinAdministratorsSid)
|
||||
must.NoError(t, err, must.Sprint("failed to create well known administrators group SID"))
|
||||
userSID, _, err := getUserGroupSIDs()
|
||||
must.NoError(t, err, must.Sprint("failed to get user SID"))
|
||||
|
||||
must.NotNil(t, matches[userSID.String()], must.Sprint("missing user ACE with GENERIC_ALL"))
|
||||
must.NotNil(t, matches[adminGroupSID.String()],
|
||||
must.Sprint("missing administrators group ACE with GENERIC_ALL"))
|
||||
|
||||
must.Eq(t, 2, len(matches), must.Sprint("unexpected GENERIC_ALL ACEs found"))
|
||||
})
|
||||
|
||||
t.Run("unrestricted already exists", func(t *testing.T) {
|
||||
ci.Parallel(t)
|
||||
path := filepath.Join(testDir, t.Name())
|
||||
must.NoError(t, os.MkdirAll(path, 0o000))
|
||||
|
||||
err := NewWindowsPaths().CreateDirectory(path, false)
|
||||
must.NoError(t, err)
|
||||
|
||||
dacl := getDirectoryDACL(t, path)
|
||||
|
||||
// No restrictions are applied, so check that all ACEs
|
||||
// are inherited from parent
|
||||
for i := range dacl.AceCount {
|
||||
ace := &windows.ACCESS_ALLOWED_ACE{}
|
||||
must.NoError(t, windows.GetAce(dacl, uint32(i), &ace), must.Sprint("failed to load ACE"))
|
||||
must.Eq(t, windows.INHERITED_ACCESS_ENTRY, ace.Header.AceFlags&windows.INHERITED_ACCESS_ENTRY,
|
||||
must.Sprint("ACE is not inherited"))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("restricted already exists", func(t *testing.T) {
|
||||
ci.Parallel(t)
|
||||
path := filepath.Join(testDir, t.Name())
|
||||
must.NoError(t, os.MkdirAll(path, 0o000))
|
||||
|
||||
err := NewWindowsPaths().CreateDirectory(path, true)
|
||||
must.NoError(t, err)
|
||||
|
||||
dacl := getDirectoryDACL(t, path)
|
||||
|
||||
// When the directory already exists, restrictions should not
|
||||
// be applied so validate that all ACEs are inherited
|
||||
for i := range dacl.AceCount {
|
||||
ace := &windows.ACCESS_ALLOWED_ACE{}
|
||||
must.NoError(t, windows.GetAce(dacl, uint32(i), &ace), must.Sprint("failed to load ACE"))
|
||||
must.Eq(t, windows.INHERITED_ACCESS_ENTRY, ace.Header.AceFlags&windows.INHERITED_ACCESS_ENTRY,
|
||||
must.Sprint("ACE is not inherited"))
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestExpand(t *testing.T) {
|
||||
t.Run("SystemDrive", func(t *testing.T) {
|
||||
t.Run("default", func(t *testing.T) {
|
||||
result, err := NewWindowsPaths().Expand(`{{.SystemDrive}}/testing`)
|
||||
must.NoError(t, err)
|
||||
must.StrNotContains(t, result, "{{.SystemDrive}}")
|
||||
})
|
||||
t.Run("custom environment variable", func(t *testing.T) {
|
||||
t.Setenv("SystemDrive", `z:`)
|
||||
result, err := NewWindowsPaths().Expand(`{{.SystemDrive}}\testing`)
|
||||
must.NoError(t, err)
|
||||
must.Eq(t, `z:\testing`, result)
|
||||
})
|
||||
t.Run("unset environment variable", func(t *testing.T) {
|
||||
t.Setenv("SystemDrive", "")
|
||||
_, err := NewWindowsPaths().Expand(`{{.SystemDrive}}\testing`)
|
||||
must.ErrorContains(t, err, "cannot detect Windows SystemDrive path")
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("ProgramData", func(t *testing.T) {
|
||||
t.Run("default", func(t *testing.T) {
|
||||
result, err := NewWindowsPaths().Expand(`{{.ProgramData}}/testing`)
|
||||
must.NoError(t, err)
|
||||
must.StrNotContains(t, result, "{{.ProgramData}}")
|
||||
})
|
||||
t.Run("custom environment variable", func(t *testing.T) {
|
||||
t.Setenv("ProgramData", `z:`)
|
||||
result, err := NewWindowsPaths().Expand(`{{.ProgramData}}\testing`)
|
||||
must.NoError(t, err)
|
||||
must.Eq(t, `z:\testing`, result)
|
||||
})
|
||||
t.Run("unset environment variable", func(t *testing.T) {
|
||||
t.Setenv("ProgramData", "")
|
||||
result, err := NewWindowsPaths().Expand(`{{.ProgramData}}\testing`)
|
||||
must.NoError(t, err)
|
||||
must.StrNotContains(t, result, "{{.ProgramData}}") // should be pulled from registry
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("ProgramFiles", func(t *testing.T) {
|
||||
t.Run("default", func(t *testing.T) {
|
||||
result, err := NewWindowsPaths().Expand(`{{.ProgramFiles}}/testing`)
|
||||
must.NoError(t, err)
|
||||
must.StrNotContains(t, result, "{{.ProgramFiles}}")
|
||||
})
|
||||
t.Run("custom environment variable", func(t *testing.T) {
|
||||
t.Setenv("ProgramFiles", `z:`)
|
||||
result, err := NewWindowsPaths().Expand(`{{.ProgramFiles}}\testing`)
|
||||
must.NoError(t, err)
|
||||
must.Eq(t, `z:\testing`, result)
|
||||
})
|
||||
t.Run("unset environment variable", func(t *testing.T) {
|
||||
t.Setenv("ProgramFiles", "")
|
||||
result, err := NewWindowsPaths().Expand(`{{.ProgramFiles}}\testing`)
|
||||
must.NoError(t, err)
|
||||
must.StrNotContains(t, result, "{{.ProgramFiles}}") // should be pulled from registry
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("missing key", func(t *testing.T) {
|
||||
_, err := NewWindowsPaths().Expand(`{{.Unknown}}\testing`)
|
||||
must.ErrorContains(t, err, "can't evaluate field")
|
||||
})
|
||||
}
|
||||
|
||||
func getDirectoryDACL(t *testing.T, path string) *windows.ACL {
|
||||
t.Helper()
|
||||
|
||||
s, err := os.Stat(path)
|
||||
must.NoError(t, err)
|
||||
must.True(t, s.IsDir(), must.Sprint("expected path to be a directory"))
|
||||
|
||||
info, err := windows.GetNamedSecurityInfo(path,
|
||||
windows.SE_FILE_OBJECT, windows.DACL_SECURITY_INFORMATION)
|
||||
must.NoError(t, err, must.Sprint("failed to get path security information"))
|
||||
|
||||
dacl, _, err := info.DACL()
|
||||
must.NoError(t, err, must.Sprint("failed to get path ACL"))
|
||||
|
||||
return dacl
|
||||
}
|
||||
11
helper/winsvc/privileged_nonwindows.go
Normal file
11
helper/winsvc/privileged_nonwindows.go
Normal file
@@ -0,0 +1,11 @@
|
||||
// Copyright (c) HashiCorp, Inc.
|
||||
// SPDX-License-Identifier: BUSL-1.1
|
||||
|
||||
//go:build !windows
|
||||
|
||||
package winsvc
|
||||
|
||||
// IsPrivilegedProcess checks if current process is a privileged windows process
|
||||
func IsPrivilegedProcess() bool {
|
||||
return false
|
||||
}
|
||||
11
helper/winsvc/privileged_windows.go
Normal file
11
helper/winsvc/privileged_windows.go
Normal file
@@ -0,0 +1,11 @@
|
||||
// Copyright (c) HashiCorp, Inc.
|
||||
// SPDX-License-Identifier: BUSL-1.1
|
||||
|
||||
package winsvc
|
||||
|
||||
import "golang.org/x/sys/windows"
|
||||
|
||||
// IsPrivilegedProcess checks if current process is a privileged windows process
|
||||
func IsPrivilegedProcess() bool {
|
||||
return windows.GetCurrentProcessToken().IsElevated()
|
||||
}
|
||||
@@ -3,6 +3,18 @@
|
||||
|
||||
package winsvc
|
||||
|
||||
const (
|
||||
WINDOWS_SERVICE_NAME = "nomad"
|
||||
WINDOWS_SERVICE_DISPLAY_NAME = "HashiCorp Nomad"
|
||||
WINDOWS_SERVICE_DESCRIPTION = "Workload scheduler and orchestrator - https://nomadproject.io"
|
||||
WINDOWS_INSTALL_BIN_DIRECTORY = `{{.ProgramFiles}}\HashiCorp\nomad\bin`
|
||||
WINDOWS_INSTALL_APPDATA_DIRECTORY = `{{.ProgramData}}\HashiCorp\nomad`
|
||||
|
||||
// Number of seconds to wait for a
|
||||
// service to reach a desired state
|
||||
WINDOWS_SERVICE_STATE_TIMEOUT = "1m"
|
||||
)
|
||||
|
||||
var chanGraceExit = make(chan int)
|
||||
|
||||
// ShutdownChannel returns a channel that sends a message that a shutdown
|
||||
|
||||
75
helper/winsvc/windows_service.go
Normal file
75
helper/winsvc/windows_service.go
Normal file
@@ -0,0 +1,75 @@
|
||||
// Copyright (c) HashiCorp, Inc.
|
||||
// SPDX-License-Identifier: BUSL-1.1
|
||||
|
||||
package winsvc
|
||||
|
||||
type ServiceStartType uint32
|
||||
|
||||
// extracted from https://pkg.go.dev/golang.org/x/sys@v0.35.0/windows/svc/mgr#StartManual
|
||||
const (
|
||||
StartManual ServiceStartType = 3
|
||||
StartAutomatic ServiceStartType = 2
|
||||
StartDisabled ServiceStartType = 4
|
||||
)
|
||||
|
||||
type WindowsServiceConfiguration struct {
|
||||
StartType ServiceStartType
|
||||
DisplayName string
|
||||
Description string
|
||||
BinaryPathName string
|
||||
}
|
||||
|
||||
type WindowsPaths interface {
|
||||
// Expand expands the path defined by the template. Supports
|
||||
// values for:
|
||||
// - SystemDrive
|
||||
// - SystemRoot
|
||||
// - ProgramData
|
||||
// - ProgramFiles
|
||||
Expand(path string) (string, error)
|
||||
|
||||
// Creates a new directory if it does not exist. If directory
|
||||
// is created and restrict_on_create is true, a restrictive
|
||||
// ACL is applied.
|
||||
CreateDirectory(path string, restrict_on_create bool) error
|
||||
}
|
||||
|
||||
type WindowsService interface {
|
||||
// Name returns the name of the service
|
||||
Name() string
|
||||
// Configure applies the configuration to the Windows service.
|
||||
// NOTE: Full configuration applied so empty values will remove existing values.
|
||||
Configure(config WindowsServiceConfiguration) error
|
||||
// Start starts the Windows service and waits for the
|
||||
// service to be running.
|
||||
Start() error
|
||||
// Stop requests the service to stop and waits for the
|
||||
// service to stop.
|
||||
Stop() error
|
||||
// Close closes the connection to the Windows service.
|
||||
Close() error
|
||||
// Delete deletes the Windows service.
|
||||
Delete() error
|
||||
// IsRunning returns if the service is currently running.
|
||||
IsRunning() (bool, error)
|
||||
// IsStopped returns if the service is currently stopped.
|
||||
IsStopped() (bool, error)
|
||||
// EnableEventlog will add or update the Windows Eventlog
|
||||
// configuration for the service. It will set supported
|
||||
// events as info, warning, and error.
|
||||
EnableEventlog() error
|
||||
// DisableEventlog will remove the Windows Eventlog configuration
|
||||
// for the service.
|
||||
DisableEventlog() error
|
||||
}
|
||||
|
||||
type WindowsServiceManager interface {
|
||||
// IsServiceRegistered returns if the service is a registered Windows service.
|
||||
IsServiceRegistered(name string) (bool, error)
|
||||
// GetService opens and returns the named service.
|
||||
GetService(name string) (WindowsService, error)
|
||||
// CreateService creates a new Windows service.
|
||||
CreateService(name, binaryPath string, config WindowsServiceConfiguration) (WindowsService, error)
|
||||
// Close closes Windows service manager connection.
|
||||
Close() error
|
||||
}
|
||||
15
helper/winsvc/windows_service_nonwindows.go
Normal file
15
helper/winsvc/windows_service_nonwindows.go
Normal file
@@ -0,0 +1,15 @@
|
||||
// Copyright (c) HashiCorp, Inc.
|
||||
// SPDX-License-Identifier: BUSL-1.1
|
||||
|
||||
//go:build !windows
|
||||
|
||||
package winsvc
|
||||
|
||||
import (
|
||||
"errors"
|
||||
)
|
||||
|
||||
// NewWindowsServiceManager returns an error
|
||||
func NewWindowsServiceManager() (WindowsServiceManager, error) {
|
||||
return nil, errors.New("Windows service manager is not supported on this platform")
|
||||
}
|
||||
256
helper/winsvc/windows_service_windows.go
Normal file
256
helper/winsvc/windows_service_windows.go
Normal file
@@ -0,0 +1,256 @@
|
||||
// Copyright (c) HashiCorp, Inc.
|
||||
// SPDX-License-Identifier: BUSL-1.1
|
||||
|
||||
package winsvc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"os"
|
||||
"os/signal"
|
||||
"reflect"
|
||||
"slices"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/nomad/helper"
|
||||
"golang.org/x/sys/windows/registry"
|
||||
"golang.org/x/sys/windows/svc"
|
||||
"golang.org/x/sys/windows/svc/eventlog"
|
||||
"golang.org/x/sys/windows/svc/mgr"
|
||||
)
|
||||
|
||||
// Base registry path for eventlog registrations
|
||||
const EVENTLOG_REGISTRY_PATH = `SYSTEM\CurrentControlSet\Services\EventLog\Application`
|
||||
|
||||
// Registry value name for supported event types
|
||||
const EVENTLOG_SUPPORTED_EVENTS_KEY = "TypesSupported"
|
||||
|
||||
// Event types registered as supported
|
||||
const EVENTLOG_SUPPORTED_EVENTS uint32 = eventlog.Error | eventlog.Warning | eventlog.Info
|
||||
|
||||
// NewWindowsServiceManager creates a new instance of the wrapper
|
||||
// to interact with the Windows service manager.
|
||||
func NewWindowsServiceManager() (WindowsServiceManager, error) {
|
||||
m, err := mgr.Connect()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &windowsServiceManager{manager: m}, nil
|
||||
}
|
||||
|
||||
type windowsServiceManager struct {
|
||||
manager *mgr.Mgr
|
||||
}
|
||||
|
||||
func (m *windowsServiceManager) IsServiceRegistered(name string) (bool, error) {
|
||||
list, err := m.manager.ListServices()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if slices.Contains(list, name) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (m *windowsServiceManager) GetService(name string) (WindowsService, error) {
|
||||
service, err := m.manager.OpenService(name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &windowsService{service: service}, nil
|
||||
}
|
||||
|
||||
func (m *windowsServiceManager) CreateService(name, bin string, config WindowsServiceConfiguration) (WindowsService, error) {
|
||||
wsvc, err := m.manager.CreateService(name, bin, mgr.Config{})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
service := &windowsService{service: wsvc}
|
||||
|
||||
// Only apply configuration if configuration is provided
|
||||
if !reflect.ValueOf(config).IsZero() {
|
||||
if err := service.Configure(config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return service, nil
|
||||
}
|
||||
|
||||
func (m *windowsServiceManager) Close() error {
|
||||
return m.manager.Disconnect()
|
||||
}
|
||||
|
||||
type windowsService struct {
|
||||
service *mgr.Service
|
||||
}
|
||||
|
||||
func (s *windowsService) Name() string {
|
||||
return s.service.Name
|
||||
}
|
||||
|
||||
func (s *windowsService) Configure(config WindowsServiceConfiguration) error {
|
||||
serviceCfg, err := s.service.Config()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
serviceCfg.StartType = uint32(config.StartType)
|
||||
serviceCfg.DisplayName = config.DisplayName
|
||||
serviceCfg.Description = config.Description
|
||||
serviceCfg.BinaryPathName = config.BinaryPathName
|
||||
|
||||
if err := s.service.UpdateConfig(serviceCfg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *windowsService) Start() error {
|
||||
if running, _ := s.IsRunning(); running {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := s.service.Start(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := waitFor(context.Background(), s.IsRunning); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *windowsService) Stop() error {
|
||||
if stopped, _ := s.IsStopped(); stopped {
|
||||
return nil
|
||||
}
|
||||
|
||||
if _, err := s.service.Control(svc.Stop); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := waitFor(context.Background(), s.IsStopped); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *windowsService) Close() error {
|
||||
return s.service.Close()
|
||||
}
|
||||
|
||||
func (s *windowsService) Delete() error {
|
||||
return s.service.Delete()
|
||||
}
|
||||
|
||||
func (s *windowsService) IsRunning() (bool, error) {
|
||||
return s.isService(svc.Running)
|
||||
}
|
||||
|
||||
func (s *windowsService) IsStopped() (bool, error) {
|
||||
return s.isService(svc.Stopped)
|
||||
}
|
||||
|
||||
func (s *windowsService) EnableEventlog() error {
|
||||
// Check if the service is already setup in the eventlog
|
||||
key, err := registry.OpenKey(registry.LOCAL_MACHINE,
|
||||
EVENTLOG_REGISTRY_PATH+`\`+s.Name(),
|
||||
registry.ALL_ACCESS,
|
||||
)
|
||||
|
||||
// If it could not be opened, assume error is caused
|
||||
// due to nonexistence. If it was for some other reason
|
||||
// the error will be encountered again when attempting to
|
||||
// create.
|
||||
if err != nil {
|
||||
if err := eventlog.InstallAsEventCreate(s.Name(), EVENTLOG_SUPPORTED_EVENTS); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
defer key.Close()
|
||||
|
||||
// Since the service is already registered, just
|
||||
// ensure it is properly configured. Currently
|
||||
// that just means the supported events.
|
||||
val, _, err := key.GetIntegerValue(EVENTLOG_SUPPORTED_EVENTS_KEY)
|
||||
if err != nil || uint32(val) != EVENTLOG_SUPPORTED_EVENTS {
|
||||
if err := key.SetDWordValue(EVENTLOG_SUPPORTED_EVENTS_KEY, EVENTLOG_SUPPORTED_EVENTS); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *windowsService) DisableEventlog() error {
|
||||
// Check if the service is currently enabled in the eventlog
|
||||
key, err := registry.OpenKey(registry.LOCAL_MACHINE,
|
||||
EVENTLOG_REGISTRY_PATH+`\`+s.Name(),
|
||||
registry.READ,
|
||||
)
|
||||
if errors.Is(err, fs.ErrNotExist) {
|
||||
return nil
|
||||
}
|
||||
defer key.Close()
|
||||
|
||||
return eventlog.Remove(s.Name())
|
||||
}
|
||||
|
||||
func (s *windowsService) isService(state svc.State) (bool, error) {
|
||||
status, err := s.service.Query()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return status.State == state, nil
|
||||
}
|
||||
|
||||
func waitFor(ctx context.Context, condition func() (bool, error)) error {
|
||||
d, err := time.ParseDuration(WINDOWS_SERVICE_STATE_TIMEOUT)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Setup a deadline
|
||||
ctx, cancel := context.WithDeadline(ctx, time.Now().Add(d))
|
||||
defer cancel()
|
||||
// Watch for any interrupts
|
||||
ctx, stop := signal.NotifyContext(ctx, os.Interrupt)
|
||||
defer stop()
|
||||
|
||||
pauseDur := time.Millisecond * 250
|
||||
t, timerStop := helper.NewSafeTimer(pauseDur)
|
||||
defer timerStop()
|
||||
|
||||
for {
|
||||
t.Reset(pauseDur)
|
||||
|
||||
complete, err := condition()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if complete {
|
||||
return nil
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return fmt.Errorf("timeout exceeded waiting for process")
|
||||
case <-t.C:
|
||||
}
|
||||
}
|
||||
}
|
||||
598
helper/winsvc/windows_service_windows_test.go
Normal file
598
helper/winsvc/windows_service_windows_test.go
Normal file
@@ -0,0 +1,598 @@
|
||||
// Copyright (c) HashiCorp, Inc.
|
||||
// SPDX-License-Identifier: BUSL-1.1
|
||||
|
||||
package winsvc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io/fs"
|
||||
"testing"
|
||||
|
||||
"github.com/hashicorp/go-uuid"
|
||||
"github.com/hashicorp/nomad/ci"
|
||||
"github.com/shoenig/test/must"
|
||||
"golang.org/x/sys/windows/registry"
|
||||
"golang.org/x/sys/windows/svc"
|
||||
"golang.org/x/sys/windows/svc/mgr"
|
||||
)
|
||||
|
||||
func TestWindowsServiceManager(t *testing.T) {
|
||||
ci.Parallel(t)
|
||||
|
||||
t.Run("IsServiceRegistered", func(t *testing.T) {
|
||||
ci.Parallel(t)
|
||||
t.Run("service does not exist", func(t *testing.T) {
|
||||
ci.Parallel(t)
|
||||
_, manager := makeManagers(t)
|
||||
|
||||
result, err := manager.IsServiceRegistered("fake-service-name")
|
||||
must.NoError(t, err, must.Sprint("check should not error"))
|
||||
must.False(t, result, must.Sprint("service should not exist"))
|
||||
})
|
||||
|
||||
t.Run("service does exist", func(t *testing.T) {
|
||||
ci.Parallel(t)
|
||||
m, manager := makeManagers(t)
|
||||
serviceName := generateStubService(t, m)
|
||||
|
||||
result, err := manager.IsServiceRegistered(serviceName)
|
||||
must.NoError(t, err, must.Sprint("check should not error"))
|
||||
must.True(t, result, must.Sprint("service should exist"))
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("GetService", func(t *testing.T) {
|
||||
ci.Parallel(t)
|
||||
t.Run("service does not exist", func(t *testing.T) {
|
||||
ci.Parallel(t)
|
||||
_, manager := makeManagers(t)
|
||||
_, err := manager.GetService("fake-service-name")
|
||||
must.ErrorContains(t, err, "specified service does not exist",
|
||||
must.Sprint("error should be generated when service does not exist"))
|
||||
})
|
||||
|
||||
t.Run("service does exist", func(t *testing.T) {
|
||||
ci.Parallel(t)
|
||||
m, manager := makeManagers(t)
|
||||
serviceName := generateStubService(t, m)
|
||||
|
||||
srv, err := manager.GetService(serviceName)
|
||||
must.NoError(t, err)
|
||||
defer srv.Close()
|
||||
must.Eq(t, serviceName, srv.Name(), must.Sprint("service name does not match"))
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("CreateService", func(t *testing.T) {
|
||||
ci.Parallel(t)
|
||||
t.Run("service does not exist", func(t *testing.T) {
|
||||
ci.Parallel(t)
|
||||
serviceName := generateServiceName()
|
||||
m, manager := makeManagers(t)
|
||||
|
||||
srv, err := manager.CreateService(serviceName, `c:\stub`, WindowsServiceConfiguration{})
|
||||
must.NoError(t, err)
|
||||
defer srv.Close()
|
||||
defer deleteStubService(t, m, serviceName)
|
||||
|
||||
must.Eq(t, serviceName, srv.Name(), must.Sprint("new service name is incorrect"))
|
||||
})
|
||||
|
||||
t.Run("service does exist", func(t *testing.T) {
|
||||
ci.Parallel(t)
|
||||
m, manager := makeManagers(t)
|
||||
serviceName := generateStubService(t, m)
|
||||
|
||||
_, err := manager.CreateService(serviceName, `c:\stub`, WindowsServiceConfiguration{})
|
||||
must.ErrorContains(t, err, "service already exists", must.Sprint("service creation should fail"))
|
||||
})
|
||||
|
||||
t.Run("with configuration", func(t *testing.T) {
|
||||
ci.Parallel(t)
|
||||
m, manager := makeManagers(t)
|
||||
serviceName := generateServiceName()
|
||||
srv, err := manager.CreateService(serviceName, `c:\stub`,
|
||||
WindowsServiceConfiguration{DisplayName: "testing service", StartType: StartDisabled})
|
||||
must.NoError(t, err, must.Sprint("service should be created"))
|
||||
defer srv.Close()
|
||||
defer deleteStubService(t, m, serviceName)
|
||||
|
||||
directSrv, err := m.OpenService(serviceName)
|
||||
must.NoError(t, err, must.Sprint("direct service connection should succeed"))
|
||||
defer directSrv.Close()
|
||||
|
||||
config, err := directSrv.Config()
|
||||
must.NoError(t, err, must.Sprint("configuration should be available from service"))
|
||||
must.Eq(t, "testing service", config.DisplayName, must.Sprint("new service name does not match"))
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// This is a simple service available in Windows. It will
|
||||
// be used to locate the executable so a test service can
|
||||
// be created using it that will allow proper start/stop
|
||||
// testing.
|
||||
const TEST_WINDOWS_SERVICE = "SNMPTrap"
|
||||
|
||||
func TestWindowsService(t *testing.T) {
|
||||
ci.Parallel(t)
|
||||
|
||||
mg, _ := makeManagers(t)
|
||||
snmpSvc, err := mg.OpenService(TEST_WINDOWS_SERVICE)
|
||||
must.NoError(t, err)
|
||||
defer snmpSvc.Close()
|
||||
snmpConfig, err := snmpSvc.Config()
|
||||
must.NoError(t, err)
|
||||
binPath := snmpConfig.BinaryPathName
|
||||
|
||||
t.Run("Name", func(t *testing.T) {
|
||||
ci.Parallel(t)
|
||||
m, manager := makeManagers(t)
|
||||
serviceName := generateStubService(t, m)
|
||||
|
||||
srv, err := manager.GetService(serviceName)
|
||||
must.NoError(t, err)
|
||||
defer srv.Close()
|
||||
|
||||
must.Eq(t, serviceName, srv.Name(), must.Sprint("service name does not match"))
|
||||
})
|
||||
|
||||
t.Run("Configure", func(t *testing.T) {
|
||||
ci.Parallel(t)
|
||||
t.Run("valid configuration", func(t *testing.T) {
|
||||
ci.Parallel(t)
|
||||
m, manager := makeManagers(t)
|
||||
serviceName := generateStubService(t, m)
|
||||
|
||||
srv, err := manager.GetService(serviceName)
|
||||
must.NoError(t, err, must.Sprint("service should be available"))
|
||||
err = srv.Configure(WindowsServiceConfiguration{
|
||||
StartType: StartDisabled,
|
||||
DisplayName: "testing display name",
|
||||
BinaryPathName: `c:\stub -with -arguments`,
|
||||
})
|
||||
must.NoError(t, err, must.Sprint("valid configuration should not error"))
|
||||
directSrv, err := m.OpenService(serviceName)
|
||||
must.NoError(t, err, must.Sprint("direct service should be available"))
|
||||
defer directSrv.Close()
|
||||
config, err := directSrv.Config()
|
||||
must.NoError(t, err, must.Sprint("direct service config should be available"))
|
||||
must.Eq(t, "testing display name", config.DisplayName, must.Sprint("display name does not match"))
|
||||
must.Eq(t, `c:\stub -with -arguments`, config.BinaryPathName, must.Sprint("binary path name does not match"))
|
||||
})
|
||||
|
||||
t.Run("invalid configuration", func(t *testing.T) {
|
||||
ci.Parallel(t)
|
||||
m, manager := makeManagers(t)
|
||||
serviceName := generateStubService(t, m)
|
||||
srv, err := manager.GetService(serviceName)
|
||||
|
||||
must.NoError(t, err, must.Sprint("service should be available"))
|
||||
err = srv.Configure(WindowsServiceConfiguration{
|
||||
DisplayName: "testing display name",
|
||||
BinaryPathName: `c:\stub -with -arguments`,
|
||||
})
|
||||
must.ErrorContains(t, err, "parameter is incorrect", must.Sprint("invalid configuration should error"))
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("Start", func(t *testing.T) {
|
||||
ci.Parallel(t)
|
||||
t.Run("when stopped", func(t *testing.T) {
|
||||
ci.Parallel(t)
|
||||
m, manager := makeManagers(t)
|
||||
runnableSvc := runnableService(t, manager, binPath)
|
||||
|
||||
directSrv, err := m.OpenService(runnableSvc.Name())
|
||||
must.NoError(t, err, must.Sprint("direct service should be available"))
|
||||
defer directSrv.Close()
|
||||
|
||||
status, err := directSrv.Query()
|
||||
must.NoError(t, err, must.Sprint("direct service status should be available"))
|
||||
if status.State != svc.Stopped {
|
||||
_, err := directSrv.Control(svc.Stop)
|
||||
must.NoError(t, err, must.Sprint("direct stop should not fail"))
|
||||
err = waitFor(context.Background(), func() (bool, error) {
|
||||
status, err := directSrv.Query()
|
||||
must.NoError(t, err, must.Sprint("direct service should be queryable"))
|
||||
return status.State == svc.Stopped, nil
|
||||
})
|
||||
must.NoError(t, err, must.Sprint("service must be stopped"))
|
||||
}
|
||||
must.NoError(t, runnableSvc.Start(), must.Sprint("service should start without error"))
|
||||
status, err = directSrv.Query()
|
||||
must.NoError(t, err, must.Sprint("direct service status should be available"))
|
||||
must.Eq(t, status.State, svc.Running, must.Sprint("service should be running"))
|
||||
})
|
||||
|
||||
t.Run("when running", func(t *testing.T) {
|
||||
ci.Parallel(t)
|
||||
m, manager := makeManagers(t)
|
||||
runnableSvc := runnableService(t, manager, binPath)
|
||||
|
||||
directSrv, err := m.OpenService(runnableSvc.Name())
|
||||
must.NoError(t, err, must.Sprint("direct service should be available"))
|
||||
defer directSrv.Close()
|
||||
|
||||
status, err := directSrv.Query()
|
||||
must.NoError(t, err, must.Sprint("direct service status should be available"))
|
||||
if status.State != svc.Running {
|
||||
must.NoError(t, directSrv.Start(), must.Sprint("direct start should not fail"))
|
||||
err := waitFor(context.Background(), func() (bool, error) {
|
||||
status, err := directSrv.Query()
|
||||
must.NoError(t, err, must.Sprint("direct service should be queryable"))
|
||||
return status.State == svc.Running, nil
|
||||
})
|
||||
must.NoError(t, err, must.Sprint("service must be running"))
|
||||
}
|
||||
must.NoError(t, runnableSvc.Start(), must.Sprint("service should start without error"))
|
||||
status, err = directSrv.Query()
|
||||
must.NoError(t, err, must.Sprint("direct service status should be available"))
|
||||
must.Eq(t, status.State, svc.Running, must.Sprint("service should be running"))
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("Stop", func(t *testing.T) {
|
||||
ci.Parallel(t)
|
||||
t.Run("when stopped", func(t *testing.T) {
|
||||
ci.Parallel(t)
|
||||
m, manager := makeManagers(t)
|
||||
runnableSvc := runnableService(t, manager, binPath)
|
||||
|
||||
directSrv, err := m.OpenService(runnableSvc.Name())
|
||||
must.NoError(t, err, must.Sprint("direct service should be available"))
|
||||
defer directSrv.Close()
|
||||
|
||||
status, err := directSrv.Query()
|
||||
must.NoError(t, err, must.Sprint("direct service status should be available"))
|
||||
if status.State != svc.Stopped {
|
||||
_, err := directSrv.Control(svc.Stop)
|
||||
must.NoError(t, err, must.Sprint("direct stop should not fail"))
|
||||
err = waitFor(context.Background(), func() (bool, error) {
|
||||
status, err := directSrv.Query()
|
||||
must.NoError(t, err, must.Sprint("direct service should be queryable"))
|
||||
return status.State == svc.Stopped, nil
|
||||
})
|
||||
must.NoError(t, err, must.Sprint("service must be stopped"))
|
||||
}
|
||||
must.NoError(t, runnableSvc.Stop(), must.Sprint("service should stop without error"))
|
||||
status, err = directSrv.Query()
|
||||
must.NoError(t, err, must.Sprint("direct service status should be available"))
|
||||
must.Eq(t, status.State, svc.Stopped, must.Sprint("service should be stopped"))
|
||||
})
|
||||
|
||||
t.Run("when running", func(t *testing.T) {
|
||||
ci.Parallel(t)
|
||||
m, manager := makeManagers(t)
|
||||
runnableSvc := runnableService(t, manager, binPath)
|
||||
|
||||
directSrv, err := m.OpenService(runnableSvc.Name())
|
||||
must.NoError(t, err, must.Sprint("direct service should be available"))
|
||||
defer directSrv.Close()
|
||||
|
||||
status, err := directSrv.Query()
|
||||
must.NoError(t, err, must.Sprint("direct service status should be available"))
|
||||
if status.State != svc.Running {
|
||||
must.NoError(t, directSrv.Start(), must.Sprint("direct start should not fail"))
|
||||
err := waitFor(context.Background(), func() (bool, error) {
|
||||
status, err := directSrv.Query()
|
||||
must.NoError(t, err, must.Sprint("direct service should be queryable"))
|
||||
return status.State == svc.Running, nil
|
||||
})
|
||||
must.NoError(t, err, must.Sprint("service must be running"))
|
||||
}
|
||||
must.NoError(t, runnableSvc.Stop(), must.Sprint("service should stop without error"))
|
||||
status, err = directSrv.Query()
|
||||
must.NoError(t, err, must.Sprint("direct service status should be available"))
|
||||
must.Eq(t, status.State, svc.Stopped, must.Sprint("service should be stopped"))
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("Delete", func(t *testing.T) {
|
||||
ci.Parallel(t)
|
||||
t.Run("when service exists", func(t *testing.T) {
|
||||
ci.Parallel(t)
|
||||
m, manager := makeManagers(t)
|
||||
|
||||
serviceName := generateStubService(t, m)
|
||||
srv, err := manager.GetService(serviceName)
|
||||
must.NoError(t, err, must.Sprint("service should be avaialble"))
|
||||
defer srv.Close()
|
||||
|
||||
must.NoError(t, srv.Delete(), must.Sprint("service should be deleted"))
|
||||
})
|
||||
|
||||
t.Run("when service does not exist", func(t *testing.T) {
|
||||
ci.Parallel(t)
|
||||
m, manager := makeManagers(t)
|
||||
|
||||
serviceName := generateStubService(t, m)
|
||||
srv, err := manager.GetService(serviceName)
|
||||
must.NoError(t, err, must.Sprint("service should be avaialble"))
|
||||
defer srv.Close()
|
||||
// Delete the service directly
|
||||
directSrv, err := m.OpenService(serviceName)
|
||||
must.NoError(t, err, must.Sprint("direct service should be available"))
|
||||
defer directSrv.Close()
|
||||
must.NoError(t, directSrv.Delete(), must.Sprint("service should be deleted"))
|
||||
|
||||
must.ErrorContains(t, srv.Delete(), "marked for deletion",
|
||||
must.Sprint("service should have already been deleted"))
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("IsRunning", func(t *testing.T) {
|
||||
ci.Parallel(t)
|
||||
t.Run("when service is not running", func(t *testing.T) {
|
||||
ci.Parallel(t)
|
||||
m, manager := makeManagers(t)
|
||||
runnableSvc := runnableService(t, manager, binPath)
|
||||
directSrv, err := m.OpenService(runnableSvc.Name())
|
||||
must.NoError(t, err, must.Sprint("direct service should be available"))
|
||||
defer directSrv.Close()
|
||||
|
||||
status, err := directSrv.Query()
|
||||
must.NoError(t, err, must.Sprint("direct service status should be available"))
|
||||
if status.State != svc.Stopped {
|
||||
_, err := directSrv.Control(svc.Stop)
|
||||
must.NoError(t, err, must.Sprint("direct stop should not fail"))
|
||||
err = waitFor(context.Background(), func() (bool, error) {
|
||||
status, err := directSrv.Query()
|
||||
must.NoError(t, err, must.Sprint("direct service should be queryable"))
|
||||
return status.State == svc.Stopped, nil
|
||||
})
|
||||
must.NoError(t, err, must.Sprint("service must be stopped"))
|
||||
}
|
||||
|
||||
srv, err := manager.GetService(directSrv.Name)
|
||||
must.NoError(t, err, must.Sprint("service should be available"))
|
||||
defer srv.Close()
|
||||
result, err := srv.IsRunning()
|
||||
must.NoError(t, err, must.Sprint("running check should not error"))
|
||||
must.False(t, result, must.Sprint("should not show service as running"))
|
||||
})
|
||||
|
||||
t.Run("when service is running", func(t *testing.T) {
|
||||
ci.Parallel(t)
|
||||
m, manager := makeManagers(t)
|
||||
runnableSvc := runnableService(t, manager, binPath)
|
||||
directSrv, err := m.OpenService(runnableSvc.Name())
|
||||
must.NoError(t, err, must.Sprint("direct service should be available"))
|
||||
defer directSrv.Close()
|
||||
|
||||
status, err := directSrv.Query()
|
||||
must.NoError(t, err, must.Sprint("direct service status should be available"))
|
||||
if status.State != svc.Running {
|
||||
must.NoError(t, directSrv.Start(), must.Sprint("direct start should not fail"))
|
||||
err := waitFor(context.Background(), func() (bool, error) {
|
||||
status, err := directSrv.Query()
|
||||
must.NoError(t, err, must.Sprint("direct service should be queryable"))
|
||||
return status.State == svc.Running, nil
|
||||
})
|
||||
must.NoError(t, err, must.Sprint("service must be running"))
|
||||
}
|
||||
srv, err := manager.GetService(directSrv.Name)
|
||||
must.NoError(t, err, must.Sprint("service should be available"))
|
||||
defer srv.Close()
|
||||
result, err := srv.IsRunning()
|
||||
must.NoError(t, err, must.Sprint("running check should not error"))
|
||||
must.True(t, result, must.Sprint("should show service as running"))
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("IsStopped", func(t *testing.T) {
|
||||
ci.Parallel(t)
|
||||
t.Run("when service is not running", func(t *testing.T) {
|
||||
ci.Parallel(t)
|
||||
m, manager := makeManagers(t)
|
||||
runnableSvc := runnableService(t, manager, binPath)
|
||||
directSrv, err := m.OpenService(runnableSvc.Name())
|
||||
must.NoError(t, err, must.Sprint("direct service should be available"))
|
||||
defer directSrv.Close()
|
||||
|
||||
status, err := directSrv.Query()
|
||||
must.NoError(t, err, must.Sprint("direct service status should be available"))
|
||||
if status.State != svc.Stopped {
|
||||
_, err := directSrv.Control(svc.Stop)
|
||||
must.NoError(t, err, must.Sprint("direct stop should not fail"))
|
||||
err = waitFor(context.Background(), func() (bool, error) {
|
||||
status, err := directSrv.Query()
|
||||
must.NoError(t, err, must.Sprint("direct service should be queryable"))
|
||||
return status.State == svc.Stopped, nil
|
||||
})
|
||||
must.NoError(t, err, must.Sprint("service must be stopped"))
|
||||
}
|
||||
|
||||
srv, err := manager.GetService(directSrv.Name)
|
||||
must.NoError(t, err, must.Sprint("service should be available"))
|
||||
defer srv.Close()
|
||||
result, err := srv.IsStopped()
|
||||
must.NoError(t, err, must.Sprint("running check should not error"))
|
||||
must.True(t, result, must.Sprint("should show service as stopped"))
|
||||
})
|
||||
|
||||
t.Run("when service is running", func(t *testing.T) {
|
||||
ci.Parallel(t)
|
||||
m, manager := makeManagers(t)
|
||||
runnableSvc := runnableService(t, manager, binPath)
|
||||
directSrv, err := m.OpenService(runnableSvc.Name())
|
||||
must.NoError(t, err, must.Sprint("direct service should be available"))
|
||||
defer directSrv.Close()
|
||||
|
||||
status, err := directSrv.Query()
|
||||
must.NoError(t, err, must.Sprint("direct service status should be available"))
|
||||
if status.State != svc.Running {
|
||||
must.NoError(t, directSrv.Start(), must.Sprint("direct start should not fail"))
|
||||
err := waitFor(context.Background(), func() (bool, error) {
|
||||
status, err := directSrv.Query()
|
||||
must.NoError(t, err, must.Sprint("direct service should be queryable"))
|
||||
return status.State == svc.Running, nil
|
||||
})
|
||||
must.NoError(t, err, must.Sprint("service must be running"))
|
||||
}
|
||||
srv, err := manager.GetService(directSrv.Name)
|
||||
must.NoError(t, err, must.Sprint("service should be available"))
|
||||
defer srv.Close()
|
||||
result, err := srv.IsStopped()
|
||||
must.NoError(t, err, must.Sprint("running check should not error"))
|
||||
must.False(t, result, must.Sprint("should not show service as stopped"))
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("EnableEventLog", func(t *testing.T) {
|
||||
ci.Parallel(t)
|
||||
t.Run("when service is not registered", func(t *testing.T) {
|
||||
ci.Parallel(t)
|
||||
m, manager := makeManagers(t)
|
||||
serviceName := generateStubService(t, m)
|
||||
|
||||
srv, err := manager.GetService(serviceName)
|
||||
must.NoError(t, err, must.Sprint("service should be available"))
|
||||
defer srv.Close()
|
||||
|
||||
must.NoError(t, srv.EnableEventlog(), must.Sprint("could not enable eventlog"))
|
||||
key, err := registry.OpenKey(registry.LOCAL_MACHINE,
|
||||
EVENTLOG_REGISTRY_PATH+`\`+serviceName,
|
||||
registry.READ,
|
||||
)
|
||||
must.NoError(t, err, must.Sprint("registry key should be available"))
|
||||
defer key.Close()
|
||||
val, _, err := key.GetIntegerValue(EVENTLOG_SUPPORTED_EVENTS_KEY)
|
||||
must.NoError(t, err, must.Sprint("registry key value should be available"))
|
||||
must.Eq(t, EVENTLOG_SUPPORTED_EVENTS, uint32(val), must.Sprint("registry value should match"))
|
||||
})
|
||||
|
||||
t.Run("when service is already registered", func(t *testing.T) {
|
||||
ci.Parallel(t)
|
||||
m, manager := makeManagers(t)
|
||||
serviceName := generateStubService(t, m)
|
||||
|
||||
srv, err := manager.GetService(serviceName)
|
||||
must.NoError(t, err, must.Sprint("service should be available"))
|
||||
defer srv.Close()
|
||||
must.NoError(t, srv.EnableEventlog(), must.Sprint("could not enable eventlog"))
|
||||
// Modify value in registry
|
||||
key, err := registry.OpenKey(registry.LOCAL_MACHINE,
|
||||
EVENTLOG_REGISTRY_PATH+`\`+serviceName,
|
||||
registry.ALL_ACCESS,
|
||||
)
|
||||
err = key.SetDWordValue(EVENTLOG_SUPPORTED_EVENTS_KEY, 1)
|
||||
must.NoError(t, err, must.Sprint("could not modify registry value"))
|
||||
|
||||
// Now enable and verify value is correct
|
||||
must.NoError(t, srv.EnableEventlog(), must.Sprint("failed to enable eventlog"))
|
||||
val, _, err := key.GetIntegerValue(EVENTLOG_SUPPORTED_EVENTS_KEY)
|
||||
must.NoError(t, err, must.Sprint("registry value should be available"))
|
||||
must.Eq(t, EVENTLOG_SUPPORTED_EVENTS, uint32(val), must.Sprint("registry value should match"))
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("DisableEventLog", func(t *testing.T) {
|
||||
ci.Parallel(t)
|
||||
t.Run("when service is not registered", func(t *testing.T) {
|
||||
ci.Parallel(t)
|
||||
m, manager := makeManagers(t)
|
||||
serviceName := generateStubService(t, m)
|
||||
|
||||
srv, err := manager.GetService(serviceName)
|
||||
must.NoError(t, err, must.Sprint("service should be available"))
|
||||
defer srv.Close()
|
||||
|
||||
must.NoError(t, srv.DisableEventlog(), must.Sprint("eventlog disable should not error"))
|
||||
})
|
||||
|
||||
t.Run("when service is registered", func(t *testing.T) {
|
||||
ci.Parallel(t)
|
||||
m, manager := makeManagers(t)
|
||||
serviceName := generateStubService(t, m)
|
||||
|
||||
srv, err := manager.GetService(serviceName)
|
||||
must.NoError(t, err, must.Sprint("service should be available"))
|
||||
defer srv.Close()
|
||||
must.NoError(t, srv.EnableEventlog(), must.Sprint("eventlog enable should not error"))
|
||||
|
||||
must.NoError(t, srv.DisableEventlog(), must.Sprint("eventlog disable should not error"))
|
||||
_, err = registry.OpenKey(registry.LOCAL_MACHINE,
|
||||
EVENTLOG_REGISTRY_PATH+`\`+serviceName,
|
||||
registry.READ,
|
||||
)
|
||||
must.ErrorIs(t, err, fs.ErrNotExist, must.Sprint("registry key should no longer exist"))
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func generateServiceName() string {
|
||||
id, err := uuid.GenerateUUID()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return id[:5]
|
||||
}
|
||||
|
||||
func generateStubService(t *testing.T, m *mgr.Mgr) string {
|
||||
t.Helper()
|
||||
|
||||
id := generateServiceName()
|
||||
_, err := m.CreateService(id, `c:\stub`, mgr.Config{})
|
||||
must.NoError(t, err, must.Sprint("failed to generate stub service"))
|
||||
|
||||
t.Cleanup(func() { deleteStubService(t, m, id) })
|
||||
|
||||
return id
|
||||
}
|
||||
|
||||
func deleteStubService(t *testing.T, m *mgr.Mgr, svcId string) {
|
||||
t.Helper()
|
||||
|
||||
srvc, err := m.OpenService(svcId)
|
||||
if err != nil {
|
||||
// If the service doesn't exist, then deletion is done so not
|
||||
// an error. Otherwise, force an error.
|
||||
must.ErrorContains(t, err, "service does not exist", must.Sprint("failed to open service"))
|
||||
return
|
||||
}
|
||||
status, err := srvc.Query()
|
||||
must.NoError(t, err, must.Sprint("failed to query service"))
|
||||
if status.State != svc.Stopped {
|
||||
status, err = srvc.Control(svc.Stop)
|
||||
must.NoError(t, err, must.Sprint("failed to stop service"))
|
||||
err := waitFor(context.Background(), func() (bool, error) {
|
||||
status, err := srvc.Query()
|
||||
must.NoError(t, err, must.Sprint("failed to query service"))
|
||||
return status.State == svc.Stopped, nil
|
||||
})
|
||||
must.NoError(t, err, must.Sprintf("could not stop service for deletion - %s", svcId))
|
||||
}
|
||||
if err := srvc.Delete(); err != nil {
|
||||
must.ErrorContains(t, err, "service has been marked for deletion", must.Sprint("failed to delete service"))
|
||||
}
|
||||
}
|
||||
|
||||
func makeManagers(t *testing.T) (*mgr.Mgr, WindowsServiceManager) {
|
||||
t.Helper()
|
||||
|
||||
winM, err := NewWindowsServiceManager()
|
||||
must.NoError(t, err, must.Sprint("failed to create service manager"))
|
||||
|
||||
m, err := mgr.Connect()
|
||||
must.NoError(t, err, must.Sprint("failed to connect to windows service manager"))
|
||||
|
||||
t.Cleanup(func() {
|
||||
winM.Close()
|
||||
m.Disconnect()
|
||||
})
|
||||
|
||||
return m, winM
|
||||
}
|
||||
|
||||
func runnableService(t *testing.T, m WindowsServiceManager, binPath string) WindowsService {
|
||||
t.Helper()
|
||||
|
||||
runnableSvc, err := m.CreateService(generateServiceName(), binPath,
|
||||
WindowsServiceConfiguration{StartType: StartManual, BinaryPathName: binPath})
|
||||
must.NoError(t, err, must.Sprint("failed to create runnable service"))
|
||||
|
||||
t.Cleanup(func() { runnableSvc.Close() })
|
||||
|
||||
return runnableSvc
|
||||
}
|
||||
Reference in New Issue
Block a user