Update provider interface to get VM by name
The following discussion from !48 (merged) should be addressed:
- @steveazz started a discussion: I don't think it's a good idea to have a pointer here; it allows us to do mutations, which is exactly what we are doing. I feel this can lead to data-race problems, and to things updating the value when we don't want or expect them to. At the moment we are only using `name` to find the VM. I get that this is specific to GCP, but at the moment that is the only provider we have. What do you think if we just have the following: `Get(ctx context.Context, name string) (config.VirtualMachine, error)`
It also seems cleaner at the API level. If we implement another provider that requires something other than the name, we can decide then what needs to be done, since we will have the context we need.
We can even refactor `getVM` to do this; I've done this in https://gitlab.com/gitlab-org/ci-cd/custom-executors/autoscaler/merge_requests/58/diffs. Full diff:
```diff
diff --git a/cmd/autoscaler/commands/custom/cleanup.go b/cmd/autoscaler/commands/custom/cleanup.go
index 41e3016..d53ef74 100644
--- a/cmd/autoscaler/commands/custom/cleanup.go
+++ b/cmd/autoscaler/commands/custom/cleanup.go
@@ -51,8 +51,7 @@ func (c *CleanupCommand) CustomExecute(ctx *cli.Context) error {
 
 	c.logger.Info("Executing the command")
 
-	vmCfg := config.VirtualMachine{Name: c.vmName}
-	err = c.provider.Get(ctx.Ctx, &vmCfg)
+	vmCfg, err := c.provider.Get(ctx.Ctx, c.vmName)
 	if err != nil {
 		return fmt.Errorf("couldn't get the VM %q details: %w", c.vmName, err)
 	}
diff --git a/cmd/autoscaler/commands/custom/run.go b/cmd/autoscaler/commands/custom/run.go
index 3e5c30e..9749c0e 100644
--- a/cmd/autoscaler/commands/custom/run.go
+++ b/cmd/autoscaler/commands/custom/run.go
@@ -68,8 +68,7 @@ func (c *RunCommand) CustomExecute(ctx *cli.Context) error {
 	args := ctx.Cli.Args()
 	scriptPath := args.Get(0)
 
-	vmCfg := config.VirtualMachine{Name: c.vmName}
-	err = c.provider.Get(ctx.Ctx, &vmCfg)
+	vmCfg, err := c.provider.Get(ctx.Ctx, c.vmName)
 	if err != nil {
 		return fmt.Errorf("couldn't get the VM %q details: %w", c.vmName, err)
 	}
diff --git a/providers/gcp/provider.go b/providers/gcp/provider.go
index ab47189..0084b00 100644
--- a/providers/gcp/provider.go
+++ b/providers/gcp/provider.go
@@ -68,23 +68,30 @@ func New(cfg globalConfig.Global, logger logging.Logger) (providers.Provider, er
 	return p, nil
 }
 
-func (p *Provider) Get(ctx context.Context, vm *globalConfig.VirtualMachine) error {
+func (p *Provider) Get(ctx context.Context, name string) (globalConfig.VirtualMachine, error) {
+	vm := globalConfig.VirtualMachine{
+		Name: name,
+		GCP: config.VirtualMachine{
+			Project: p.config.Project,
+			Zone:    p.config.Zone,
+		},
+	}
 	vm.GCP.Project = p.config.Project
 	vm.GCP.Zone = p.config.Zone
 
-	instance, err := p.getVM(ctx, *vm)
+	instance, err := p.getVM(ctx, name)
 	if err != nil {
-		return fmt.Errorf("couldn't get the instance details: %w", err)
+		return globalConfig.VirtualMachine{}, fmt.Errorf("couldn't get the instance details: %w", err)
 	}
 
 	vm.IPAddress = instance.NetworkInterfaces[0].NetworkIP
 
-	err = p.loadCredentialsFromMetadata(ctx, vm, instance)
+	err = p.loadCredentialsFromMetadata(ctx, &vm, instance)
 	if err != nil {
 		p.logger.WithError(err).Warn("couldn't load the credentials")
 	}
 
-	return nil
+	return vm, nil
 }
 
 type metadataCredentials struct {
@@ -195,7 +202,7 @@ func (p *Provider) Create(ctx context.Context, e executors.Executor, vm *globalC
 	vm.GCP.Project = p.config.Project
 	vm.GCP.Zone = p.config.Zone
 
-	instance, err := p.getVM(ctx, *vm)
+	instance, err := p.getVM(ctx, vm.Name)
 	if err != nil {
 		return fmt.Errorf("couldn't get the instance details: %w", err)
 	}
@@ -304,17 +311,17 @@ func (p *Provider) waitForOperation(logger logging.Logger, operation string, nam
 	return err
 }
 
-func (p *Provider) getVM(ctx context.Context, vm globalConfig.VirtualMachine) (*compute.Instance, error) {
+func (p *Provider) getVM(ctx context.Context, name string) (*compute.Instance, error) {
 	s, err := p.getInstancesService(ctx)
 	if err != nil {
 		return nil, err
 	}
 
-	return s.Get(p.config.Project, p.config.Zone, vm.Name).Do()
+	return s.Get(p.config.Project, p.config.Zone, name).Do()
 }
 
 func (p *Provider) saveCredentialsInMetadata(ctx context.Context, logger logging.Logger, vm globalConfig.VirtualMachine) error {
-	instance, err := p.getVM(ctx, vm)
+	instance, err := p.getVM(ctx, vm.Name)
 	if err != nil {
 		return fmt.Errorf("couldn't get the instance details: %w", err)
 	}
diff --git a/providers/provider.go b/providers/provider.go
index 3d73f69..6dd8bd2 100644
--- a/providers/provider.go
+++ b/providers/provider.go
@@ -12,7 +12,7 @@ import (
 const factoryType = "provider"
 
 type Provider interface {
-	Get(ctx context.Context, vm *config.VirtualMachine) error
+	Get(ctx context.Context, name string) (config.VirtualMachine, error)
 	Create(ctx context.Context, e executors.Executor, vm *config.VirtualMachine) error
 	Delete(ctx context.Context, vm config.VirtualMachine) error
 }
```