Skip to content

Commit 7c8e5c9

Browse files
committed
Add device flow
1 parent 3f69989 commit 7c8e5c9

File tree

5 files changed

+427
-75
lines changed

5 files changed

+427
-75
lines changed

cli.go

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -151,18 +151,20 @@ func (h *CLIHandler) Auth(c *api.Client, m map[string]string) (*api.Secret, erro
151151
var pollInterval string
152152
var interval int
153153
var state string
154+
var userCode string
154155
var listener net.Listener
155156

156157
if secret != nil {
157158
pollInterval, _ = secret.Data["poll_interval"].(string)
158159
state, _ = secret.Data["state"].(string)
160+
userCode, _ = secret.Data["user_code"].(string)
159161
}
160-
if callbackMode == "direct" {
162+
if callbackMode != "client" {
161163
if state == "" {
162-
return nil, errors.New("no state returned in direct callback mode")
164+
return nil, errors.New("no state returned in " + callbackMode + " callback mode")
163165
}
164166
if pollInterval == "" {
165-
return nil, errors.New("no poll_interval returned in direct callback mode")
167+
return nil, errors.New("no poll_interval returned in " + callbackMode + " callback mode")
166168
}
167169
interval, err = strconv.Atoi(pollInterval)
168170
if err != nil {
@@ -218,6 +220,31 @@ func (h *CLIHandler) Auth(c *api.Client, m map[string]string) (*api.Secret, erro
218220
// authorization is pending, try again
219221
}
220222
}
223+
if userCode != "" {
224+
fmt.Fprintf(os.Stderr, "When prompted, enter code %s\n\n", userCode)
225+
}
226+
227+
if callbackMode != "client" {
228+
data := map[string]interface{}{
229+
"state": state,
230+
"client_nonce": clientNonce,
231+
}
232+
pollUrl := fmt.Sprintf("auth/%s/oidc/poll", mount)
233+
for {
234+
time.Sleep(time.Duration(interval) * time.Second)
235+
236+
secret, err := c.Logical().Write(pollUrl, data)
237+
if err == nil {
238+
return secret, nil
239+
}
240+
if strings.HasSuffix(err.Error(), "slow_down") {
241+
interval *= 2
242+
} else if !strings.HasSuffix(err.Error(), "authorization_pending") {
243+
return nil, err
244+
}
245+
// authorization is pending, try again
246+
}
247+
}
221248

222249
// Start local server
223250
go func() {
@@ -376,8 +403,9 @@ Configuration:
376403
Vault role of type "OIDC" to use for authentication.
377404
378405
%s=<string>
379-
Mode of callback: "direct" for direct connection to Vault or "client"
380-
for connection to command line client (default: client).
406+
Mode of callback: "direct" for direct connection to Vault, "client"
407+
for connection to command line client, or "device" for device flow
408+
which has no callback (default: client).
381409
382410
%s=<string>
383411
Optional address to bind the OIDC callback listener to in client callback

path_config.go

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,12 @@ import (
88
"crypto"
99
"crypto/tls"
1010
"crypto/x509"
11+
"encoding/json"
1112
"errors"
13+
"fmt"
14+
"io/ioutil"
1215
"net/http"
16+
"net/url"
1317
"strings"
1418

1519
"github.com/hashicorp/cap/jwt"
@@ -163,6 +167,91 @@ func (b *jwtAuthBackend) config(ctx context.Context, s logical.Storage) (*jwtCon
163167
return config, nil
164168
}
165169

170+
func contactIssuer(ctx context.Context, uri string, data *url.Values, ignoreBad bool) ([]byte, error) {
171+
var req *http.Request
172+
var err error
173+
if data == nil {
174+
req, err = http.NewRequest("GET", uri, nil)
175+
} else {
176+
req, err = http.NewRequest("POST", uri, strings.NewReader(data.Encode()))
177+
}
178+
if err != nil {
179+
return nil, nil
180+
}
181+
if data != nil {
182+
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
183+
}
184+
185+
client, ok := ctx.Value(oauth2.HTTPClient).(*http.Client)
186+
if !ok {
187+
client = http.DefaultClient
188+
}
189+
resp, err := client.Do(req.WithContext(ctx))
190+
if err != nil {
191+
return nil, nil
192+
}
193+
defer resp.Body.Close()
194+
195+
body, err := ioutil.ReadAll(resp.Body)
196+
if err != nil {
197+
return nil, nil
198+
}
199+
200+
if resp.StatusCode != http.StatusOK && (!ignoreBad || resp.StatusCode != http.StatusBadRequest) {
201+
return nil, fmt.Errorf("%s: %s", resp.Status, body)
202+
}
203+
204+
return body, nil
205+
}
206+
207+
// Discover the device_authorization_endpoint URL and store it in the config
208+
// This should be in coreos/go-oidc but they don't yet support device flow
209+
// At the same time, look up token_endpoint and store it as well
210+
// Returns nil on success, otherwise returns an error
211+
func (b *jwtAuthBackend) configDeviceAuthURL(ctx context.Context, s logical.Storage) (error) {
212+
config, err := b.config(ctx, s)
213+
if err != nil {
214+
return err
215+
}
216+
217+
b.l.Lock()
218+
defer b.l.Unlock()
219+
220+
if config.OIDCDeviceAuthURL != "" {
221+
if config.OIDCDeviceAuthURL == "N/A" {
222+
return fmt.Errorf("no device auth endpoint url discovered")
223+
}
224+
return nil
225+
}
226+
227+
caCtx, err := b.createCAContext(b.providerCtx, config.OIDCDiscoveryCAPEM)
228+
if err != nil {
229+
return errwrap.Wrapf("error creating context for device auth: {{err}}", err)
230+
}
231+
232+
issuer := config.OIDCDiscoveryURL
233+
234+
wellKnown := strings.TrimSuffix(issuer, "/") + "/.well-known/openid-configuration"
235+
body, err := contactIssuer(caCtx, wellKnown, nil, false)
236+
if err != nil {
237+
return errwrap.Wrapf("error reading issuer config: {{err}}", err)
238+
}
239+
240+
var daj struct {
241+
DeviceAuthURL string `json:"device_authorization_endpoint"`
242+
TokenURL string `json:"token_endpoint"`
243+
}
244+
err = json.Unmarshal(body, &daj)
245+
if err != nil || daj.DeviceAuthURL == "" {
246+
b.cachedConfig.OIDCDeviceAuthURL = "N/A"
247+
return fmt.Errorf("no device auth endpoint url discovered")
248+
}
249+
250+
b.cachedConfig.OIDCDeviceAuthURL = daj.DeviceAuthURL
251+
b.cachedConfig.OIDCTokenURL = daj.TokenURL
252+
return nil
253+
}
254+
166255
func (b *jwtAuthBackend) pathConfigRead(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
167256
config, err := b.config(ctx, req.Storage)
168257
if err != nil {
@@ -420,6 +509,9 @@ type jwtConfig struct {
420509
NamespaceInState bool `json:"namespace_in_state"`
421510

422511
ParsedJWTPubKeys []crypto.PublicKey `json:"-"`
512+
// These are looked up from OIDCDiscoveryURL when needed
513+
OIDCDeviceAuthURL string `json:"-"`
514+
OIDCTokenURL string `json:"-"`
423515
}
424516

425517
const (

0 commit comments

Comments
 (0)