// sus - a simple url shortener Copyright (C) 2022 Akbar Rahman (hi@alv.cx) // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU General Public License as published by // the Free Software Foundation, either version 3 of the License, or // (at your option) any later version. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU General Public License for more details. // // You should have received a copy of the GNU General Public License // along with this program. If not, see . package main import ( "crypto/hmac" "crypto/sha256" "encoding/hex" "fmt" "io" "log" "net/http" "os" "strconv" "time" "github.com/go-redis/redis/v8" "github.com/gorilla/mux" "golang.org/x/net/context" ) var client = redis.NewClient(&redis.Options{ Addr: "sus-redis:6379", Password: "", DB: 0, }) var SECRET string var INDEX_GET_REDIRECT = "http://alv.cx" var MAX_AGE_MS int64 = 500 func main() { r := mux.NewRouter() r.HandleFunc("/{shortlink}", shortlinkHandler) r.HandleFunc("/{shortlink}/", shortlinkHandler) r.HandleFunc("/", indexHandler) listenAddress := "0.0.0.0:80" if p, ok := os.LookupEnv("LISTEN_ADDRESS"); ok { listenAddress = p } if p, ok := os.LookupEnv("SECRET"); ok { SECRET = p } if p, ok := os.LookupEnv("INDEX_GET_REDIRECT"); ok { INDEX_GET_REDIRECT = p } if p, ok := os.LookupEnv("MAX_AGE_MS"); ok { if v, err := strconv.ParseInt(p, 10, 64); err != nil { fmt.Printf("Unable to parse environment variable MAX_AGE_MS: %v\n", p) } else { MAX_AGE_MS = v } } log.Fatal(http.ListenAndServe(listenAddress, r)) } func indexHandler(w http.ResponseWriter, r *http.Request) { fmt.Println("indexHandler called") if r.Method != "POST" { http.Redirect(w, r, INDEX_GET_REDIRECT, 302) return } r.ParseForm() command := r.PostForm.Get("Command") shortlink := r.PostForm.Get("Shortlink") value := r.PostForm.Get("Value") req_timestamp := r.PostForm.Get("Timestamp") req_timestamp_int, err := strconv.ParseInt(req_timestamp, 10, 64) if err != nil { w.WriteHeader(http.StatusBadRequest) w.Write([]byte("Bad request")) return } cur_timestamp := time.Now().UnixNano() if req_timestamp_int+MAX_AGE_MS*1000*1000 < cur_timestamp { w.WriteHeader(http.StatusBadRequest) w.Write([]byte("Bad request")) return } fmt.Printf("req_timestamp: %v, command: %v, shortlink: %v, value: %v\n", req_timestamp, command, shortlink, value) signature := r.Header.Get("Signature") calculatedSignature := fmt.Sprintf( "SUS-SIGNATURE-%v", getSha256HMACSignature( []byte(SECRET), req_timestamp+":"+command+":"+shortlink+":"+value, ), ) if signature != calculatedSignature { fmt.Println("signature do no match") fmt.Println(signature) fmt.Println(calculatedSignature) w.WriteHeader(401) w.Write([]byte("401 Unauthorized")) return } if command == "create" { ctx := context.Background() _, err := client.Get(ctx, shortlink).Result() if err == redis.Nil { err = client.Set(ctx, shortlink, value, 0).Err() if err != nil { fmt.Println(err) w.WriteHeader(500) w.Write([]byte("500 Internal Server Error")) return } w.WriteHeader(200) w.Write([]byte("200 Success")) return } else if err != nil { fmt.Println(err) w.WriteHeader(500) w.Write([]byte("500 Internal Server Error")) return } fmt.Println(err) w.WriteHeader(403) w.Write([]byte("403 Forbidden")) return } if command == "delete" { if value != "confirm" { w.WriteHeader(400) w.Write([]byte("400 Bad Request")) } ctx := context.Background() if err := client.Del(ctx, shortlink).Err(); err != nil { w.WriteHeader(500) w.Write([]byte("500 Internal Server Error")) } } if command == "list" { ctx := context.Background() keys, err := client.Keys(ctx, "*").Result() if err != nil { fmt.Println(err) w.WriteHeader(500) w.Write([]byte("500 Internal Server Error")) return } resp := "" for _, key := range keys { value, err = client.Get(ctx, key).Result() if err == redis.Nil { w.WriteHeader(500) w.Write([]byte("500 Internal Server Error")) return } resp += key + ":" + value + "\n" } w.WriteHeader(200) w.Write([]byte(resp)) } } func shortlinkHandler(w http.ResponseWriter, r *http.Request) { fmt.Println("shortlinkHandler called") shortlink := string(mux.Vars(r)["shortlink"]) ctx := context.Background() redirect, err := client.Get(ctx, shortlink).Result() fmt.Printf("shortlink: %v, redirect: %v\n", shortlink, redirect) if err == redis.Nil { w.WriteHeader(404) w.Write([]byte("404 Not Found")) return } else if err != nil { fmt.Println(err) w.WriteHeader(500) w.Write([]byte("500 Internal Server Error")) return } http.Redirect(w, r, redirect, 302) return } func getSha256HMACSignature(secret []byte, data string) string { h := hmac.New(sha256.New, secret) io.WriteString(h, data) return hex.EncodeToString(h.Sum(nil)) }