rpc: Move register args initial validation into separate function. (#26446)

The RPC handler function is quite long, so moving the argument
validation into its own function reduces this and makes sense from
an organisation view.
This commit is contained in:
James Rasell
2025-08-08 14:47:27 +02:00
committed by GitHub
parent b6f90d0562
commit f5c02671e5
3 changed files with 186 additions and 27 deletions

View File

@@ -125,33 +125,9 @@ func (n *Node) Register(args *structs.NodeRegisterRequest, reply *structs.NodeUp
defer metrics.MeasureSince([]string{"nomad", "client", "register"}, time.Now())
// Validate the arguments
if args.Node == nil {
return fmt.Errorf("missing node for client registration")
}
if args.Node.ID == "" {
return fmt.Errorf("missing node ID for client registration")
}
if args.Node.Datacenter == "" {
return fmt.Errorf("missing datacenter for client registration")
}
if args.Node.Name == "" {
return fmt.Errorf("missing node name for client registration")
}
if len(args.Node.Attributes) == 0 {
return fmt.Errorf("missing attributes for client registration")
}
if args.Node.SecretID == "" {
return fmt.Errorf("missing node secret ID for client registration")
}
if args.Node.NodePool != "" {
err := structs.ValidateNodePoolName(args.Node.NodePool)
if err != nil {
return fmt.Errorf("invalid node pool: %v", err)
}
if args.Node.NodePool == structs.NodePoolAll {
return fmt.Errorf("node is not allowed to register in node pool %q", structs.NodePoolAll)
}
// Perform validation of the base provided request.
if err := args.Validate(); err != nil {
return err
}
// Default the status if none is given

View File

@@ -552,6 +552,40 @@ type NodeRegisterRequest struct {
WriteRequest
}
// Validate checks that the NodeRegisterRequest is valid. Any returned error can
// be sent back to the client as a response to the RPC call.
func (n *NodeRegisterRequest) Validate() error {
if n.Node == nil {
return errors.New("missing node for client registration")
}
if n.Node.ID == "" {
return errors.New("missing node ID for client registration")
}
if n.Node.Datacenter == "" {
return errors.New("missing datacenter for client registration")
}
if n.Node.Name == "" {
return errors.New("missing node name for client registration")
}
if len(n.Node.Attributes) == 0 {
return errors.New("missing attributes for client registration")
}
if n.Node.SecretID == "" {
return errors.New("missing node secret ID for client registration")
}
if n.Node.NodePool != "" {
if err := ValidateNodePoolName(n.Node.NodePool); err != nil {
return fmt.Errorf("invalid node pool: %v", err)
}
if n.Node.NodePool == NodePoolAll {
return fmt.Errorf("node is not allowed to register in node pool %q", NodePoolAll)
}
}
return nil
}
// ShouldGenerateNodeIdentity compliments the functionality within
// AuthenticateNodeIdentityGenerator to determine whether a new node identity
// should be generated within the RPC handler.

View File

@@ -282,6 +282,155 @@ func TestGenerateNodeIdentityClaims(t *testing.T) {
must.NotNil(t, claims.Expiry)
}
func TestNodeRegisterRequest_Validate(t *testing.T) {
ci.Parallel(t)
testCases := []struct {
name string
request *NodeRegisterRequest
errorContains string
}{
{
name: "valid",
request: &NodeRegisterRequest{
Node: &Node{
ID: "node-id",
SecretID: "node-secret-id",
Name: "node-name",
NodePool: "node-pool",
NodeClass: "node-class",
Datacenter: "node-datacenter",
Attributes: map[string]string{"key1": "value1"},
},
},
errorContains: "",
},
{
name: "nil node",
request: &NodeRegisterRequest{
Node: nil,
},
errorContains: "missing node for client registration",
},
{
name: "missing ID",
request: &NodeRegisterRequest{
Node: &Node{
ID: "",
SecretID: "node-secret-id",
Name: "node-name",
NodePool: "node-pool",
NodeClass: "node-class",
Datacenter: "node-datacenter",
Attributes: map[string]string{"key1": "value1"},
},
},
errorContains: "missing node ID for client registration",
},
{
name: "missing datacenter",
request: &NodeRegisterRequest{
Node: &Node{
ID: "node-id",
SecretID: "node-secret-id",
Name: "node-name",
NodePool: "node-pool",
NodeClass: "node-class",
Datacenter: "",
Attributes: map[string]string{"key1": "value1"},
},
},
errorContains: "missing datacenter for client registration",
},
{
name: "missing name",
request: &NodeRegisterRequest{
Node: &Node{
ID: "node-id",
SecretID: "node-secret-id",
Name: "",
NodePool: "node-pool",
NodeClass: "node-class",
Datacenter: "node-datacenter",
Attributes: map[string]string{"key1": "value1"},
},
},
errorContains: "missing node name for client registration",
},
{
name: "missing attributes",
request: &NodeRegisterRequest{
Node: &Node{
ID: "node-id",
SecretID: "node-secret-id",
Name: "node-name",
NodePool: "node-pool",
NodeClass: "node-class",
Datacenter: "node-datacenter",
Attributes: map[string]string{},
},
},
errorContains: "missing attributes for client registration",
},
{
name: "missing secret ID",
request: &NodeRegisterRequest{
Node: &Node{
ID: "node-id",
SecretID: "",
Name: "node-name",
NodePool: "node-pool",
NodeClass: "node-class",
Datacenter: "node-datacenter",
Attributes: map[string]string{"key1": "value1"},
},
},
errorContains: "missing node secret ID for client registration",
},
{
name: "invalid node pool name",
request: &NodeRegisterRequest{
Node: &Node{
ID: "node-id",
SecretID: "node-secret-id",
Name: "node-name",
NodePool: "****",
NodeClass: "node-class",
Datacenter: "node-datacenter",
Attributes: map[string]string{"key1": "value1"},
},
},
errorContains: `invalid node pool: invalid name "****"`,
},
{
name: "invalid node pool all use",
request: &NodeRegisterRequest{
Node: &Node{
ID: "node-id",
SecretID: "node-secret-id",
Name: "node-name",
NodePool: NodePoolAll,
NodeClass: "node-class",
Datacenter: "node-datacenter",
Attributes: map[string]string{"key1": "value1"},
},
},
errorContains: `node is not allowed to register in node pool "all"`,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
actualError := tc.request.Validate()
if tc.errorContains != "" {
must.ErrorContains(t, actualError, tc.errorContains)
} else {
must.NoError(t, actualError)
}
})
}
}
func TestNodeRegisterRequest_ShouldGenerateNodeIdentity(t *testing.T) {
ci.Parallel(t)