package auth
import (
"errors"
"fmt"
"log"
"net/url"
"regexp"
"time"
"github.com/nsqio/nsq/internal/http_api"
)
type (
	// Authorization is one grant from an auth server. It pairs a topic pattern
	// and a list of channel patterns with the permissions allowed on them.
	// QueryAuthd only accepts the permissions "subscribe" and "publish".
	// IsAllowed treats Topic and every entry of Channels as regular expressions.
	Authorization struct {
		Topic       string   `json:"topic"`
		Channels    []string `json:"channels"`
		Permissions []string `json:"permissions"`
	}

	// State is the authorization state for one client, filled in from an auth
	// server response by QueryAuthd.
	State struct {
		// TTL is the lifetime of this state in seconds; QueryAuthd rejects
		// values <= 0.
		TTL int `json:"ttl"`
		// Authorizations lists the grants; access is allowed if any one of
		// them allows it (see State.IsAllowed).
		Authorizations []Authorization `json:"authorizations"`
		Identity       string          `json:"identity"`
		IdentityURL    string          `json:"identity_url"`
		// Expires carries no json tag. QueryAuthd overwrites it with the query
		// time plus TTL seconds; IsExpired compares it against time.Now().
		Expires time.Time
	}
)
// HasPermission reports whether permission appears verbatim (exact,
// case-sensitive string equality) in a.Permissions.
func (a *Authorization) HasPermission(permission string) bool {
	for i := range a.Permissions {
		if a.Permissions[i] == permission {
			return true
		}
	}
	return false
}
// IsAllowed reports whether this authorization grants access to topic and
// channel.
//
// The permission required depends on channel:
//   - a non-empty channel (subscribing) requires "subscribe";
//   - an empty channel (publishing) requires "publish".
//
// In both cases Topic must match topic, and at least one entry of Channels must
// match channel. When publishing, channel is "", so some channel pattern must
// match the empty string.
//
// Patterns are unanchored Go regular expressions: they match anywhere in the
// name unless the pattern itself uses ^ and $.
//
// A pattern that fails to compile never grants access. An invalid Topic denies
// access, and an invalid channel pattern is treated as matching nothing. This
// replaces the previous regexp.MustCompile, which panicked the caller. QueryAuthd
// validates patterns, but an Authorization built any other way carries no such
// guarantee.
//
// NOTE(review): patterns are recompiled on every call; consider caching the
// compiled regexps if this shows up in profiles.
func (a *Authorization) IsAllowed(topic, channel string) bool {
	required := "publish"
	if channel != "" {
		required = "subscribe"
	}
	if !a.HasPermission(required) {
		return false
	}

	// Fail closed: a compile error denies access instead of panicking.
	if ok, err := regexp.MatchString(a.Topic, topic); err != nil || !ok {
		return false
	}

	for _, pattern := range a.Channels {
		// An uncompilable channel pattern grants nothing; keep checking the rest.
		if ok, err := regexp.MatchString(pattern, channel); err == nil && ok {
			return true
		}
	}
	return false
}
// IsAllowed reports whether at least one entry of a.Authorizations allows
// access to topic and channel. Entries are checked in order and the search
// stops at the first match. An empty Authorizations list allows nothing.
func (a *State) IsAllowed(topic, channel string) bool {
	for i := range a.Authorizations {
		if a.Authorizations[i].IsAllowed(topic, channel) {
			return true
		}
	}
	return false
}
// IsExpired reports whether a.Expires lies strictly before the current time.
// A State whose Expires was never set (the zero time.Time) is therefore always
// expired.
func (a *State) IsExpired() bool {
	return a.Expires.Before(time.Now())
}
// QueryAnyAuthd tries each auth server address in authd, in order, via
// QueryAuthd. It returns the State from the first server that answers
// successfully.
//
// Each failing server is logged and skipped. If no server succeeds, an error
// is returned:
//   - when authd is empty, it is the plain "Unable to access auth server";
//   - otherwise the same message is kept as a prefix, followed by the last
//     server's error, so callers can see why access failed.
func QueryAnyAuthd(authd []string, remoteIP, tlsEnabled, authSecret string,
	connectTimeout time.Duration, requestTimeout time.Duration) (*State, error) {
	var lastErr error
	for _, a := range authd {
		authState, err := QueryAuthd(a, remoteIP, tlsEnabled, authSecret, connectTimeout, requestTimeout)
		if err != nil {
			log.Printf("Error: failed auth against %s %s", a, err)
			lastErr = err
			continue
		}
		return authState, nil
	}
	if lastErr == nil {
		return nil, errors.New("Unable to access auth server")
	}
	// The original prefix is kept for callers that match on the message text.
	// %s (not %w) avoids depending on the Go toolchain version.
	return nil, fmt.Errorf("Unable to access auth server (last error: %s)", lastErr)
}
// QueryAuthd asks a single auth server for a client's authorization State.
//
// Request: an HTTP GET to "http://<authd>/auth". The query string carries
// remote_ip, tls and secret, taken verbatim from remoteIP, tlsEnabled and
// authSecret. The request goes through http_api.NewClient with the given
// connect and request timeouts.
//
// The decoded response is rejected with an error when:
//   - any permission is not "subscribe" or "publish";
//   - a topic or channel pattern does not compile as a regexp;
//   - TTL is not positive.
//
// On success, Expires is set to now plus TTL seconds.
//
// NOTE(review): the secret travels in the query string over plain http.
func QueryAuthd(authd, remoteIP, tlsEnabled, authSecret string,
	connectTimeout time.Duration, requestTimeout time.Duration) (*State, error) {
	// Encode sorts keys, so the literal yields the same query as repeated Set calls.
	query := url.Values{
		"remote_ip": []string{remoteIP},
		"tls":       []string{tlsEnabled},
		"secret":    []string{authSecret},
	}
	endpoint := fmt.Sprintf("http://%s/auth?%s", authd, query.Encode())

	client := http_api.NewClient(nil, connectTimeout, requestTimeout)
	state := &State{}
	if err := client.GETV1(endpoint, state); err != nil {
		return nil, err
	}

	// Validate the response before trusting it; IsAllowed relies on these patterns.
	for i := range state.Authorizations {
		grant := &state.Authorizations[i]
		for _, perm := range grant.Permissions {
			if perm != "subscribe" && perm != "publish" {
				return nil, fmt.Errorf("unknown permission %s", perm)
			}
		}
		if _, err := regexp.Compile(grant.Topic); err != nil {
			return nil, fmt.Errorf("unable to compile topic %q %s", grant.Topic, err)
		}
		for _, ch := range grant.Channels {
			if _, err := regexp.Compile(ch); err != nil {
				return nil, fmt.Errorf("unable to compile channel %q %s", ch, err)
			}
		}
	}

	if state.TTL <= 0 {
		return nil, fmt.Errorf("invalid TTL %d (must be >0)", state.TTL)
	}
	lifetime := time.Duration(state.TTL) * time.Second
	state.Expires = time.Now().Add(lifetime)
	return state, nil
}