5 Commits

6 changed files with 81 additions and 136 deletions

1
.gitignore vendored
View File

@@ -1,2 +1,3 @@
redis redis
sus
.env .env

View File

@@ -1,4 +1,4 @@
FROM golang:1.18 FROM golang:1.16
WORKDIR /go/src/app WORKDIR /go/src/app
COPY . . COPY . .

View File

@@ -6,7 +6,10 @@ services:
ports: [ "8430:80" ] ports: [ "8430:80" ]
environment: environment:
- SECRET=${SECRET} - SECRET=${SECRET}
- MAX_AGE_MS=${MAX_AGE_MS}
restart: unless-stopped
redis: redis:
hostname: sus-redis
image: redis:7 image: redis:7
volumes: [ "./redis:/data" ] volumes: [ "./redis:/data" ]
ports: [ "6379:6379" ] restart: unless-stopped

135
main.go
View File

@@ -25,7 +25,6 @@ import (
"net/http" "net/http"
"os" "os"
"strconv" "strconv"
"strings"
"time" "time"
"github.com/go-redis/redis/v8" "github.com/go-redis/redis/v8"
@@ -34,17 +33,19 @@ import (
) )
var client = redis.NewClient(&redis.Options{ var client = redis.NewClient(&redis.Options{
Addr: "redis:6379", Addr: "sus-redis:6379",
Password: "", Password: "",
DB: 0, DB: 0,
}) })
var SECRET string var SECRET string
var INDEX_GET_REDIRECT = "http://alv.cx" var INDEX_GET_REDIRECT = "http://alv.cx"
var MAX_AGE_MS int64 = 500
func main() { func main() {
r := mux.NewRouter() r := mux.NewRouter()
r.HandleFunc("/{shortlink}", shortlinkHandler) r.HandleFunc("/{shortlink}", shortlinkHandler)
r.HandleFunc("/{shortlink}/", shortlinkHandler)
r.HandleFunc("/", indexHandler) r.HandleFunc("/", indexHandler)
listenAddress := "0.0.0.0:80" listenAddress := "0.0.0.0:80"
@@ -61,6 +62,14 @@ func main() {
INDEX_GET_REDIRECT = p INDEX_GET_REDIRECT = p
} }
if p, ok := os.LookupEnv("MAX_AGE_MS"); ok {
if v, err := strconv.ParseInt(p, 10, 64); err != nil {
fmt.Printf("Unable to parse environment variable MAX_AGE_MS: %v\n", p)
} else {
MAX_AGE_MS = v
}
}
log.Fatal(http.ListenAndServe(listenAddress, r)) log.Fatal(http.ListenAndServe(listenAddress, r))
} }
@@ -75,27 +84,30 @@ func indexHandler(w http.ResponseWriter, r *http.Request) {
command := r.PostForm.Get("Command") command := r.PostForm.Get("Command")
shortlink := r.PostForm.Get("Shortlink") shortlink := r.PostForm.Get("Shortlink")
redirect := r.PostForm.Get("Redirect") value := r.PostForm.Get("Value")
alt_redirect := r.PostForm.Get("AltRedirect") req_timestamp := r.PostForm.Get("Timestamp")
alt_condition := r.PostForm.Get("AltCondition") req_timestamp_int, err := strconv.ParseInt(req_timestamp, 10, 64)
fmt.Printf("command: %v, shortlink: %v, redirect: %v, alt_redirect: %v, alt_condition: %v\n", if err != nil {
command, w.WriteHeader(http.StatusBadRequest)
shortlink, w.Write([]byte("Bad request"))
redirect, return
alt_redirect, }
alt_condition,
)
formstring := command + ":" + shortlink + ":" + redirect + ":" + alt_condition + ":" + alt_redirect cur_timestamp := time.Now().UnixNano()
if req_timestamp_int+MAX_AGE_MS*1000*1000 < cur_timestamp {
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte("Bad request"))
return
}
fmt.Println("formstring: " + formstring) fmt.Printf("req_timestamp: %v, command: %v, shortlink: %v, value: %v\n", req_timestamp, command, shortlink, value)
signature := r.Header.Get("Signature") signature := r.Header.Get("Signature")
calculatedSignature := fmt.Sprintf( calculatedSignature := fmt.Sprintf(
"SUS-SIGNATURE-%v", "SUS-SIGNATURE-%v",
getSha256HMACSignature( getSha256HMACSignature(
[]byte(SECRET), []byte(SECRET),
command+":"+shortlink+":"+redirect+":"+alt_condition+":"+alt_redirect, req_timestamp+":"+command+":"+shortlink+":"+value,
), ),
) )
@@ -113,21 +125,7 @@ func indexHandler(w http.ResponseWriter, r *http.Request) {
ctx := context.Background() ctx := context.Background()
_, err := client.Get(ctx, shortlink).Result() _, err := client.Get(ctx, shortlink).Result()
if err == redis.Nil { if err == redis.Nil {
err = client.Set(ctx, shortlink+":redirect", shortlink, 0).Err() err = client.Set(ctx, shortlink, value, 0).Err()
if err != nil {
fmt.Println(err)
w.WriteHeader(500)
w.Write([]byte("500 Internal Server Error"))
return
}
err = client.Set(ctx, shortlink+":altcondition", alt_condition, 0).Err()
if err != nil {
fmt.Println(err)
w.WriteHeader(500)
w.Write([]byte("500 Internal Server Error"))
return
}
err = client.Set(ctx, shortlink+":altredirect", alt_redirect, 0).Err()
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
@@ -155,16 +153,13 @@ func indexHandler(w http.ResponseWriter, r *http.Request) {
} }
if command == "delete" { if command == "delete" {
if value != "confirm" {
w.WriteHeader(400)
w.Write([]byte("400 Bad Request"))
}
ctx := context.Background() ctx := context.Background()
if err := client.Del(ctx, shortlink+":redirect").Err(); err != nil { if err := client.Del(ctx, shortlink).Err(); err != nil {
w.WriteHeader(500)
w.Write([]byte("500 Internal Server Error"))
}
if err := client.Del(ctx, shortlink+":altredirect").Err(); err != nil {
w.WriteHeader(500)
w.Write([]byte("500 Internal Server Error"))
}
if err := client.Del(ctx, shortlink+":altcondition").Err(); err != nil {
w.WriteHeader(500) w.WriteHeader(500)
w.Write([]byte("500 Internal Server Error")) w.Write([]byte("500 Internal Server Error"))
} }
@@ -182,13 +177,13 @@ func indexHandler(w http.ResponseWriter, r *http.Request) {
resp := "" resp := ""
for _, key := range keys { for _, key := range keys {
shortlink, err = client.Get(ctx, key).Result() value, err = client.Get(ctx, key).Result()
if err == redis.Nil { if err == redis.Nil {
w.WriteHeader(500) w.WriteHeader(500)
w.Write([]byte("500 Internal Server Error")) w.Write([]byte("500 Internal Server Error"))
return return
} }
resp += key + ":" + shortlink + "\n" resp += key + ":" + value + "\n"
} }
w.WriteHeader(200) w.WriteHeader(200)
@@ -201,9 +196,7 @@ func shortlinkHandler(w http.ResponseWriter, r *http.Request) {
fmt.Println("shortlinkHandler called") fmt.Println("shortlinkHandler called")
shortlink := string(mux.Vars(r)["shortlink"]) shortlink := string(mux.Vars(r)["shortlink"])
ctx := context.Background() ctx := context.Background()
client.Incr(ctx, shortlink + ":hits") redirect, err := client.Get(ctx, shortlink).Result()
redirect, err := client.Get(ctx, shortlink+":redirect").Result()
fmt.Printf("shortlink: %v, redirect: %v\n", shortlink, redirect) fmt.Printf("shortlink: %v, redirect: %v\n", shortlink, redirect)
if err == redis.Nil { if err == redis.Nil {
w.WriteHeader(404) w.WriteHeader(404)
@@ -211,67 +204,13 @@ func shortlinkHandler(w http.ResponseWriter, r *http.Request) {
return return
} else if err != nil { } else if err != nil {
fmt.Println(err) fmt.Println(err)
fmt.Println(0)
w.WriteHeader(500) w.WriteHeader(500)
w.Write([]byte("500 Internal Server Error")) w.Write([]byte("500 Internal Server Error"))
return return
} }
altcondition, err := client.Get(ctx, shortlink+":altcondition").Result()
if err == redis.Nil {
http.Redirect(w, r, redirect, 302) http.Redirect(w, r, redirect, 302)
return return
} else if err != nil {
fmt.Println(err)
fmt.Println(1)
w.WriteHeader(500)
w.Write([]byte("500 Internal Server Error"))
return
}
altredirect, err := client.Get(ctx, shortlink+":altredirect").Result()
if err == redis.Nil {
http.Redirect(w, r, redirect, 302)
return
} else if err != nil {
fmt.Println(err)
fmt.Println(2)
w.WriteHeader(500)
w.Write([]byte("500 Internal Server Error"))
return
}
altcondition_split := strings.Split(altcondition, ",")
ac_varname := altcondition_split[0]
ac_operator := altcondition_split[1]
ac_required_value, _ := strconv.Atoi(altcondition_split[2])
var ac_varval int
if ac_varname != "timestamp" {
ac_varvalstr, err := client.Get(ctx, shortlink+":"+ac_varname).Result()
if err == redis.Nil {
ac_varval = 0
} else if err != nil {
fmt.Println(err)
fmt.Println(3)
w.WriteHeader(500)
w.Write([]byte("500 Internal Server Error"))
return
} else {
ac_varval, _ = strconv.Atoi(ac_varvalstr)
}
} else {
ac_varval = int(time.Now().Unix())
}
if (ac_operator == "eq" && ac_varval == ac_required_value) ||
(ac_operator == "gt" && ac_varval > ac_required_value) ||
(ac_operator == "lt" && ac_varval < ac_required_value) {
http.Redirect(w, r, altredirect, 307)
} else {
http.Redirect(w, r, redirect, 307)
}
} }
func getSha256HMACSignature(secret []byte, data string) string { func getSha256HMACSignature(secret []byte, data string) string {

View File

@@ -7,21 +7,11 @@ sus URL shortener
- creating a new shortlink at https://pls.cx/shortlink - creating a new shortlink at https://pls.cx/shortlink
susmng [-s pls.cx] create -l shortlink -r https://example.com susmng [-s pls.cx] create -l shortlink -v https://example.com
- creating a new shortlink at https://pls.cx/shortlink which redirects to https://example.com/a the first n times, https://example.com/b any other times
susmng [-s pls.cx] create -l shortlink -r https://example.com/a \
-c hits,gt,<n> -a https://example.com/b
- creating a new shortlink at https://pls.cx/shortlink which redirects to https://example.com/before before unix timestamp n (seconds), https://example.com/after after that
susmng [-s pls.cx] create -l shortlink -r https://example.com/before \
-c timestamp,gt,<n> -a https://example.com/after
- deleting the shortlink https://pls.cx/shortlink - deleting the shortlink https://pls.cx/shortlink
susmng [-s pls.cx] delete -l shortlink susmng [-s pls.cx] delete -l shortlink -v confirm
- listing all shortlinks on the server pls.cx - listing all shortlinks on the server pls.cx
@@ -49,6 +39,7 @@ flag is not provided.
- `LISTEN_ADDRESS`---the address the server is listening on (default is `0.0.0.0:80`) - `LISTEN_ADDRESS`---the address the server is listening on (default is `0.0.0.0:80`)
- `INDEX_GET_REDIRECT`---the URL the user should be redirected to if they try to access `/` on the - `INDEX_GET_REDIRECT`---the URL the user should be redirected to if they try to access `/` on the
server (default is `http://alv.cx`) server (default is `http://alv.cx`)
- `MAX_AGE_MS`---how old a request can be (in milliseconds) before the server will refuse to process it. (default is 500 milliseconds)
### setting up susmng ### setting up susmng

View File

@@ -8,6 +8,7 @@ import pathlib
import os import os
import json import json
import sys import sys
import time
def get_args(): def get_args():
@@ -18,10 +19,8 @@ def get_args():
parser.add_argument('command') parser.add_argument('command')
parser.add_argument('-s', '--server', default="") parser.add_argument('-s', '--server', default="")
parser.add_argument('-l', '--shortlink', default="") parser.add_argument('-l', '--shortlink', default="")
parser.add_argument('-r', '--redirect', default="") parser.add_argument('-v', '--value', default="")
parser.add_argument('-a', '--alt-redirect', default="") parser.add_argument('-c', '--config', type=pathlib.Path, default=pathlib.Path(os.path.expanduser('~/.config/susmng/config.json')))
parser.add_argument('-c', '--alt-condition', default="")
parser.add_argument('--config', type=pathlib.Path, default=pathlib.Path(os.path.expanduser('~/.config/susmng/config.json')))
parser.add_argument('-H', '--http', action='store_true') parser.add_argument('-H', '--http', action='store_true')
return parser.parse_args() return parser.parse_args()
@@ -50,25 +49,37 @@ def main(args):
secret = config['secrets'][server] secret = config['secrets'][server]
formstring = args.command+":"+args.shortlink+":"+args.redirect+":"+args.alt_condition+":"+args.alt_redirect if args.command == "delete" and args.value != "confirm":
print(f"{formstring=}") print("--value not set to 'confirm'... delete operation may fail")
# accoring to python documentation (https://docs.python.org/3/library/time.html#time.time)
# this function does not explicitly have to use unix time, and implementation is dependent
# platform.
# most platforms (windows, unix) will probably give unix time though.
#
# the server side (main.go file) does explicitly use unix time (time.Now().UnixNano()) to get
# this number, but hopefully there should be no issues on most platforms.
timestamp = str(time.time_ns())
r = requests.post(f"{'http' if args.http else 'https'}://{server}",
data = { data = {
'Command': args.command, 'Command': args.command,
'Shortlink': args.shortlink, 'Shortlink': args.shortlink,
'Redirect': args.redirect, 'Value': args.value,
'AltCondition': args.alt_condition, 'Timestamp': timestamp,
'AltRedirect': args.alt_redirect, }
},
headers = { headers = {
'Signature': 'SUS-SIGNATURE-' + hmac.new( 'Signature': 'SUS-SIGNATURE-' + hmac.new(
secret.encode("UTF-8"), secret.encode("UTF-8"),
formstring.encode("UTF-8"), (timestamp + ":" + args.command + ":" + args.shortlink + ":" + args.value).encode("UTF-8"),
hashlib.sha256 hashlib.sha256
).hexdigest() ).hexdigest()
} }
)
print(f"{data=}")
print(f"{headers=}")
r = requests.post(f"{'http' if args.http else 'https'}://{server}", data=data, headers=headers)
print(r, file=sys.stderr) print(r, file=sys.stderr)
print(r.content.decode().strip()) print(r.content.decode().strip())
return 0 return 0