package main
import (
"database/sql"
"encoding/json"
"fmt"
"io/ioutil"
"log"
"path/filepath"
"regexp"
"sort"
"strconv"
"strings"
"time"
errors "git.sequentialread.com/forest/pkg-errors"
_ "github.com/lib/pq"
)
// RowScanner abstracts a single result row that can be decoded with a
// database/sql style Scan call. Within this file it is satisfied by *sql.Rows
// (see GetVPSInstances, which hands its rows to rowToVPSInstance).
type RowScanner interface {
	// Scan copies the columns of the current row into the given destinations.
	Scan(dest ...interface{}) error
}
// DBModel bundles the database handle with the settings the query methods in
// this file rely on. It is constructed by initDatabase.
type DBModel struct {
	// DB is the handle every method issues its queries through.
	DB *sql.DB
	// subdomainRegex validates the subdomain passed to SetFreeSubdomain.
	subdomainRegex *regexp.Regexp
	// Port allocation parameters consumed by GetNextReservedPorts: the size of
	// each handed-out range and the bounds the allocation counter wraps within.
	PortRangeSize, MinListeningPort, MaxListeningPort int
}
// TunnelSettings is the per-tenant tunnel configuration: an inclusive-looking
// port range (PortStart..PortEnd, as produced by GetNextReservedPorts) and the
// domain names the tenant is authorized to use.
type TunnelSettings struct {
	PortStart         int
	PortEnd           int
	AuthorizedDomains []string
}

// DeepEquals reports whether settings and other describe the same port range
// and the same authorized domains, ignoring the order of the domains.
//
// Neither receiver nor argument is modified: the domain lists are compared on
// sorted copies, so callers keep their original ordering and concurrent
// readers of the slices are not disturbed. Two nil pointers compare equal; a
// nil pointer never equals a non-nil one.
func (settings *TunnelSettings) DeepEquals(other *TunnelSettings) bool {
	if settings == nil || other == nil {
		return settings == other
	}
	if settings.PortStart != other.PortStart || settings.PortEnd != other.PortEnd {
		return false
	}

	// Sort private copies; sort.Strings works in place and would otherwise
	// reorder the callers' slices.
	mine := append([]string(nil), settings.AuthorizedDomains...)
	theirs := append([]string(nil), other.AuthorizedDomains...)
	sort.Strings(mine)
	sort.Strings(theirs)

	return strings.Join(mine, ",") == strings.Join(theirs, ",")
}
// TenantInfo is the in-memory view of a tenant assembled by GetTenants and
// GetTenant. Not every loader fills every field: GetTenants, for example,
// leaves Email, APITokens and ExternalDomains at their zero values.
type TenantInfo struct {
	Id      int
	Created time.Time
	// Email is populated by GetTenant; Subdomain is "" when the tenant row has
	// no subdomain.
	Email, Subdomain  string
	DedicatedVPSCount int
	Bytes             int64
	// SMSAlarmNumber is "" when the stored sms_alarm_number is NULL.
	SMSAlarmNumber string
	// ServiceLimitCents, BillingAlarmCents and PortBucket mirror the tenants
	// columns service_limit_cents, billing_alarm_cents and port_bucket.
	ServiceLimitCents, BillingAlarmCents, PortBucket int
	// TunnelSettings carries the reserved port range and authorized domains.
	TunnelSettings *TunnelSettings
	Deactivated    bool
	// APITokens and ExternalDomains are populated by GetTenant only.
	APITokens       []APIToken
	ExternalDomains []ExternalDomain
}
// TenantVPSInstance is one row of the tenant_vps_instance table as returned by
// GetTenantVPSInstanceRows: the association between a tenant and a VPS
// instance for a billing period.
type TenantVPSInstance struct {
	TenantId int
	// ServiceProvider and ProviderInstanceId together identify the VPS
	// (see GetVPSInstanceId).
	ServiceProvider, ProviderInstanceId string
	// ShadowConfig is decoded from the JSON stored in the shadow_config column.
	ShadowConfig *TunnelSettings
	Bytes        int64
	Active       bool
	// DeactivatedAt is nil when the deactivated_at column is NULL.
	DeactivatedAt *time.Time
}
// APIToken is one row of the api_tokens table as loaded by GetTenant.
type APIToken struct {
	// Name is the token's key_name.
	Name   string
	Active bool
	// HashedToken is the stored hashed_token value.
	HashedToken string
	// Created and LastUsed mirror the created and last_used columns.
	Created, LastUsed time.Time
}
// ExternalDomain is a tenant-supplied domain name together with its current
// verification status, as computed by GetTenant.
type ExternalDomain struct {
	// DomainName is the domain_name column of external_domains.
	DomainName string
	// IsValid is true when the domain's last_verified timestamp is newer than
	// the verification cutoff used by GetTenant.
	IsValid bool
}
const (
	// DomainVerificationPollingInterval is the period used when deciding
	// whether an external domain still counts as verified: GetTenants and
	// GetTenant accept a domain whose last_verified is within this interval
	// plus one minute of the current time.
	DomainVerificationPollingInterval = time.Hour
)
// GetVPSInstanceId returns the identifier of the VPS this row refers to, in
// the form "<ServiceProvider>-<ProviderInstanceId>".
func (i *TenantVPSInstance) GetVPSInstanceId() string {
	return i.ServiceProvider + "-" + i.ProviderInstanceId
}
// initDatabase opens the database described by config, makes sure the
// schema_version bookkeeping table exists (creating it at version 1 when it is
// missing), and then applies the "up" or "down" migration scripts found in the
// schema_versions directory, one version at a time, until the stored version
// equals the version this build expects.
//
// Migration scripts are looked up by the file name pattern
// "<NN>_<up|down>_*.sql". Each script is executed wrapped in a serializable
// transaction that also records the resulting schema version.
//
// Every failure is fatal (log.Fatal / log.Fatalf): the function either returns
// a ready-to-use *DBModel or terminates the process.
func initDatabase(config *Config) *DBModel {
	desiredSchemaVersion := 3

	db, err := sql.Open(config.DatabaseType, config.DatabaseConnectionString)
	if err != nil {
		log.Fatal(err)
	}
	if err := db.Ping(); err != nil {
		log.Fatalf("failed to open database connection: %+v", err)
	}

	// Bootstrap the schema_version table when it does not exist yet.
	var tableName string
	err = db.QueryRow(
		"SELECT table_name FROM information_schema.tables WHERE table_schema = $1 AND table_name = $2",
		config.DatabaseSchema, "schema_version",
	).Scan(&tableName)
	if err == sql.ErrNoRows {
		_, err := db.Exec(`
CREATE TABLE schema_version (
version INT PRIMARY KEY NOT NULL
) ;
INSERT INTO schema_version ( version ) VALUES ( 1 ) ;
`)
		if err != nil {
			log.Fatalf("failed to create schema_version table: %+v", err)
		}
	} else if err != nil {
		log.Fatalf("failed to check whether or not schema_version table exists: %+v", err)
	}

	readSchemaVersionFromDatabase := func() int {
		var currentSchemaVersion int
		err := db.QueryRow("SELECT version FROM schema_version").Scan(&currentSchemaVersion)
		if err != nil {
			log.Fatalf("failed to select currentSchemaVersion: %+v", err)
		}
		return currentSchemaVersion
	}
	currentSchemaVersion := readSchemaVersionFromDatabase()

	files, err := ioutil.ReadDir("schema_versions")
	if err != nil {
		log.Fatalf("failed to list schema_versions: %+v", err)
	}

	// getMigrationScript locates the script named after scriptVersion and
	// direction and wraps it in a transaction that stores resultVersion in
	// schema_version. The two versions differ for "down" scripts: the script
	// is named after the version being removed, but the database ends up one
	// version lower.
	getMigrationScript := func(scriptVersion, resultVersion int, direction string) (filename string, content string) {
		prefix := fmt.Sprintf("%02d_%s_", scriptVersion, direction)
		for _, file := range files {
			if !file.IsDir() && strings.HasPrefix(file.Name(), prefix) && strings.HasSuffix(file.Name(), ".sql") {
				filename = filepath.Join("schema_versions", file.Name())
				contentsBytes, err := ioutil.ReadFile(filename)
				if err != nil {
					log.Fatalf("failed to read file '%s': %+v", filename, err)
				}
				content = fmt.Sprintf(`
BEGIN TRANSACTION ISOLATION LEVEL SERIALIZABLE ;
% s
UPDATE schema_version SET version = % d ;
COMMIT TRANSACTION ;
`, string(contentsBytes), resultVersion)
				break
			}
		}
		if content == "" {
			log.Fatalf("didn't find any files in schema_versions matching %s*.sql", prefix)
		}
		return filename, content
	}

	for currentSchemaVersion != desiredSchemaVersion {
		log.Printf(
			"currentSchemaVersion (%d) != desiredSchemaVersion (%d), running database schema migration(s)\n",
			currentSchemaVersion, desiredSchemaVersion,
		)

		var expectedSchemaVersion int
		var filename, content string
		if currentSchemaVersion < desiredSchemaVersion {
			expectedSchemaVersion = currentSchemaVersion + 1
			filename, content = getMigrationScript(expectedSchemaVersion, expectedSchemaVersion, "up")
		} else {
			// The "down" script of the current version takes the schema to
			// the previous version, so that is the version to record.
			expectedSchemaVersion = currentSchemaVersion - 1
			filename, content = getMigrationScript(currentSchemaVersion, expectedSchemaVersion, "down")
		}
		if _, err := db.Exec(content); err != nil {
			log.Fatalf("failed to execute database migration script %s: %+v", filename, err)
		}

		actualSchemaVersion := readSchemaVersionFromDatabase()
		if expectedSchemaVersion != actualSchemaVersion {
			log.Fatalf(
				"expecting database schema version (%d) to be %d after running database migration script %s",
				actualSchemaVersion, expectedSchemaVersion, filename,
			)
		}
		currentSchemaVersion = actualSchemaVersion
	}

	return &DBModel{
		DB:             db,
		subdomainRegex: regexp.MustCompile("^[a-z0-9]([a-z0-9-_]*[a-z0-9]+)?$"),
		// TODO make these configurable?
		PortRangeSize:    20,
		MinListeningPort: 10000,
		MaxListeningPort: 30000,
	}
}
// ---------------- DBModel methods ----------------
// Register creates a tenant row for the given email address (stored
// lower-cased) and hashed password and returns the new tenant id.
//
// It returns an error when a tenant with that email already exists, when the
// existence check itself fails, or when the insert fails. A failing existence
// check is reported as such rather than being mistaken for a duplicate
// account.
func (model *DBModel) Register(email, hashedPassword string) (int, error) {
	email = strings.ToLower(email)

	var existingAccount string
	err := model.DB.QueryRow("SELECT id FROM tenants WHERE email = $1", email).Scan(&existingAccount)
	if err == nil {
		return 0, fmt.Errorf("email address '%s' is already associated with an account", email)
	}
	if err != sql.ErrNoRows {
		return 0, errors.Wrap(err, "Register(): could not check whether the email address is already registered")
	}

	var inserted int
	err = model.DB.QueryRow(
		"INSERT INTO tenants (email, hashed_password) VALUES ($1, $2) RETURNING id",
		email, hashedPassword,
	).Scan(&inserted)
	if err != nil {
		return 0, errors.Wrap(err, "Register(): could not insert row into tenants table")
	}
	return inserted, nil
}
// CreateEmailVerificationToken stores an email verification token for the
// tenant. The expiry is converted to UTC before it is written.
func (model *DBModel) CreateEmailVerificationToken(token string, tenantId int, expires time.Time) error {
	const insertToken = "INSERT INTO email_verification_tokens (token, tenant_id, expires) VALUES ($1, $2, $3)"

	if _, err := model.DB.Exec(insertToken, token, tenantId, expires.UTC()); err != nil {
		return errors.Wrap(err, "CreateEmailVerificationToken(): could not insert row into email_verification_tokens table")
	}
	return nil
}
// VerifyEmail checks the verification token for the tenant and, when it exists
// and has not expired, marks the tenant's email as verified and deletes the
// token.
//
// The returned errors are deliberately generic; the underlying database errors
// are written to the log instead. The tenant is marked verified before the
// token is deleted, so a failed update leaves the token available for a retry.
// Failing to delete an already-used token is logged but does not fail the
// verification.
func (model *DBModel) VerifyEmail(token string, tenantId int) error {
	var expires time.Time
	err := model.DB.QueryRow(
		"SELECT expires FROM email_verification_tokens WHERE token = $1 AND tenant_id = $2",
		token, tenantId,
	).Scan(&expires)
	if err != nil && err != sql.ErrNoRows {
		log.Printf("VerifyEmail(): query error %+v", err)
	}
	if err != nil || time.Now().UTC().After(expires) {
		return errors.New("email verification token was invalid or expired")
	}

	if _, err := model.DB.Exec("UPDATE tenants SET email_verified = TRUE WHERE id = $1", tenantId); err != nil {
		log.Printf("VerifyEmail(): could not mark tenant %d as verified: %+v", tenantId, err)
		return errors.New("internal error occurred during email verification")
	}

	// Best effort: the email is already verified at this point.
	if _, err := model.DB.Exec("DELETE FROM email_verification_tokens WHERE token = $1", token); err != nil {
		log.Printf("VerifyEmail(): could not delete used verification token: %+v", err)
	}
	return nil
}
// GetLoginInfo looks up the tenant with the given email address (compared
// lower-cased) and returns its id, hashed password and email-verified flag.
//
// When the lookup does not succeed the hashed password is "" and the flag is
// false. Because the signature has no error result, query errors other than
// "no such row" are written to the log so that a database problem is not
// silently indistinguishable from an unknown email address.
func (model *DBModel) GetLoginInfo(email string) (int, string, bool) {
	tenantId := 0
	var hashedPassword string
	var emailVerified bool
	err := model.DB.QueryRow(
		"SELECT id, hashed_password, email_verified FROM tenants WHERE email = $1",
		strings.ToLower(email),
	).Scan(&tenantId, &hashedPassword, &emailVerified)
	if err != nil {
		if err != sql.ErrNoRows {
			log.Printf("GetLoginInfo(): query error %+v", err)
		}
		return tenantId, "", false
	}
	if tenantId == 0 {
		return tenantId, "", false
	}
	return tenantId, hashedPassword, emailVerified
}
// GetSession loads the session stored under the given cookie id, joined with
// the owning tenant. When cameFromLaxCookie is true the session is only
// returned if the tenant has lax_cookie set.
//
// It returns (nil, nil) when no matching session exists.
func (model *DBModel) GetSession(id string, cameFromLaxCookie bool) (*Session, error) {
	laxFilter := ""
	if cameFromLaxCookie {
		laxFilter = "and tenants.lax_cookie = TRUE"
	}
	query := fmt.Sprintf(`
SELECT session_cookies . tenant_id , tenants . email , tenants . email_verified , session_cookies . expires
FROM session_cookies JOIN tenants on session_cookies . tenant_id = tenants . id
WHERE session_cookies . id = $ 1 % s `,
		laxFilter,
	)

	var (
		tenantId int
		email    string
		verified bool
		expires  time.Time
	)
	switch err := model.DB.QueryRow(query, id).Scan(&tenantId, &email, &verified, &expires); {
	case err == sql.ErrNoRows:
		return nil, nil
	case err != nil:
		return nil, errors.Wrapf(err, "GetSession(id=%s, cameFromLaxCookie=%t): ", id, cameFromLaxCookie)
	}

	session := &Session{
		TenantId:      tenantId,
		Email:         email,
		EmailVerified: verified,
		Expires:       expires.UTC(),
	}
	return session, nil
}
// SetSession makes the given session the tenant's only session: it stores the
// session's LaxCookie flag on the tenant, removes the tenant's existing
// session cookies and inserts a new one under id (expiry written in UTC).
//
// The three statements run in order and the first failure stops the sequence.
func (model *DBModel) SetSession(id string, session *Session) error {
	type statement struct {
		query string
		args  []interface{}
	}
	statements := []statement{
		{
			"UPDATE tenants SET lax_cookie = $1 WHERE id = $2",
			[]interface{}{session.LaxCookie, session.TenantId},
		},
		{
			"DELETE FROM session_cookies WHERE tenant_id = $1",
			[]interface{}{session.TenantId},
		},
		{
			"INSERT INTO session_cookies (id, tenant_id, expires) VALUES ($1, $2, $3)",
			[]interface{}{id, session.TenantId, session.Expires.UTC()},
		},
	}

	var err error
	for _, stmt := range statements {
		if _, err = model.DB.Exec(stmt.query, stmt.args...); err != nil {
			break
		}
	}
	return errors.Wrap(err, "SetSession(): ")
}
// LogoutTenant deletes every session cookie belonging to the tenant and
// returns the database error, if any, unchanged.
func (model *DBModel) LogoutTenant(tenantId int) error {
	const deleteSessions = "DELETE FROM session_cookies WHERE tenant_id = $1"

	if _, err := model.DB.Exec(deleteSessions, tenantId); err != nil {
		return err
	}
	return nil
}
// GetUserByAPIToken resolves an active API token (given in hashed form) to a
// Session describing the owning tenant.
//
// It returns an error for tokens shorter than 8 characters, (nil, nil) when no
// active token matches, and a wrapped error when either query fails. The
// returned Session has no Expires value set.
func (model *DBModel) GetUserByAPIToken(hashedApiToken string) (*Session, error) {
	if len(hashedApiToken) < 8 {
		return nil, errors.New("The given hashedApiToken token was too short")
	}

	var tenantId int
	switch err := model.DB.QueryRow(
		` SELECT tenant_id FROM api_tokens WHERE hashed_token = $1 AND active = TRUE `,
		hashedApiToken,
	).Scan(&tenantId); {
	case err == sql.ErrNoRows:
		return nil, nil
	case err != nil:
		return nil, errors.Wrapf(err, "GetUserByAPIToken(hashedApiToken=%s): ", hashedApiToken)
	}

	var (
		email    string
		verified bool
	)
	if err := model.DB.QueryRow(
		` SELECT email, email_verified FROM tenants WHERE id = $1 `, tenantId,
	).Scan(&email, &verified); err != nil {
		return nil, errors.Wrapf(err, "GetUserByAPIToken(hashedApiToken=%s): ", hashedApiToken)
	}

	session := &Session{
		TenantId:      tenantId,
		Email:         email,
		EmailVerified: verified,
	}
	return session, nil
}
// SetFreeSubdomain assigns the (lower-cased) subdomain to the tenant unless
// some tenant already uses it.
//
// The boolean result is true when the subdomain was already taken, in which
// case nothing is changed. An error is returned when the subdomain does not
// match the model's subdomain pattern or when a database operation fails.
// A failure while reading the "is it taken?" result is reported as an error
// rather than being interpreted as "the subdomain is free".
func (model *DBModel) SetFreeSubdomain(tenantId int, subdomain string) (bool, error) {
	subdomain = strings.ToLower(subdomain)
	if !model.subdomainRegex.MatchString(subdomain) {
		return false, errors.Errorf("SetFreeSubdomain(): subdomain '%s' is invalid", subdomain)
	}

	rows, err := model.DB.Query(` SELECT 1 FROM tenants WHERE subdomain = $1 `, subdomain)
	if err != nil {
		return false, errors.Wrap(err, "SetFreeSubdomain(): ")
	}
	alreadyTaken := rows.Next()
	err = rows.Err()
	// The result set is no longer needed; release it before the UPDATE.
	rows.Close()
	if err != nil {
		return false, errors.Wrap(err, "SetFreeSubdomain(): ")
	}
	if alreadyTaken {
		return true, nil
	}

	_, err = model.DB.Exec("UPDATE tenants SET subdomain = $1 WHERE id = $2", subdomain, tenantId)
	if err != nil {
		return false, errors.Wrap(err, "SetFreeSubdomain(): ")
	}
	return false, nil
}
// SetReservedPorts stores the reserved port range and port bucket on the
// tenant row.
func (model *DBModel) SetReservedPorts(tenantId, portStart, portEnd, portBucket int) error {
	const updatePorts = "UPDATE tenants SET port_start = $1, port_end = $2, port_bucket = $3 WHERE id = $4"

	if _, err := model.DB.Exec(updatePorts, portStart, portEnd, portBucket, tenantId); err != nil {
		return errors.Wrap(err, "SetReservedPorts(): ")
	}
	return nil
}
// CreateAPIToken stores a new API token (already hashed) under the given key
// name for the tenant.
func (model *DBModel) CreateAPIToken(tenantId int, keyName, hashedAPIToken string) error {
	const insertToken = "INSERT INTO api_tokens (tenant_id, key_name, hashed_token) VALUES ($1, $2, $3)"

	if _, err := model.DB.Exec(insertToken, tenantId, keyName, hashedAPIToken); err != nil {
		return errors.Wrap(err, "CreateAPIToken(): ")
	}
	return nil
}
// SetAPITokenActive enables or disables the tenant's API token with the given
// key name.
//
// The token is matched on the key_name column, the same column CreateAPIToken
// writes and GetTenant reads.
func (model *DBModel) SetAPITokenActive(tenantId int, keyName string, active bool) error {
	_, err := model.DB.Exec(
		"UPDATE api_tokens SET active = $1 WHERE tenant_id = $2 AND key_name = $3",
		active, tenantId, keyName,
	)
	if err != nil {
		return errors.Wrap(err, "SetAPITokenActive(): ")
	}
	return nil
}
// DeleteAPIToken removes the tenant's API token with the given key name.
//
// The token is matched on the key_name column, the same column CreateAPIToken
// writes and GetTenant reads.
func (model *DBModel) DeleteAPIToken(tenantId int, keyName string) error {
	_, err := model.DB.Exec(
		"DELETE FROM api_tokens WHERE tenant_id = $1 AND key_name = $2",
		tenantId, keyName,
	)
	if err != nil {
		return errors.Wrap(err, "DeleteAPIToken(): ")
	}
	return nil
}
// GetNextReservedPorts hands out the next port range from the
// reserved_ports_counter table and advances the counter by PortRangeSize.
//
// It returns (first port, last port, bucket, nil). When the next range would
// reach MaxListeningPort the bucket is incremented and allocation restarts at
// MinListeningPort. On failure it returns (0, -1, -1, err).
func (model *DBModel) GetNextReservedPorts() (int, int, int, error) {
	var firstPort, bucket int
	if err := model.DB.QueryRow("SELECT port, bucket FROM reserved_ports_counter").Scan(&firstPort, &bucket); err != nil {
		return 0, -1, -1, errors.Wrap(err, "GetNextReservedPorts(): ")
	}

	if firstPort+model.PortRangeSize >= model.MaxListeningPort {
		// The current bucket is exhausted: start over in the next one.
		bucket++
		firstPort = model.MinListeningPort
	}

	// The stored counter is the first port of the range after this one.
	nextCounter := firstPort + model.PortRangeSize
	if _, err := model.DB.Exec("UPDATE reserved_ports_counter SET port = $1, bucket = $2", nextCounter, bucket); err != nil {
		return 0, -1, -1, errors.Wrap(err, "GetNextReservedPorts(): ")
	}
	return firstPort, nextCounter - 1, bucket, nil
}
// GetVPSInstances returns every VPS instance that is not marked deleted, keyed
// by the instance's GetId() value.
//
// An error is returned when the query fails, when a row cannot be decoded, or
// when iterating over the result set ends with an error; a partially read
// result is never returned as if it were complete.
func (model *DBModel) GetVPSInstances() (map[string]*VPSInstance, error) {
	rows, err := model.DB.Query(`
SELECT service_provider , provider_instance_id , tenant_id , ipv4 , ipv6 , bytes_monthly , created , deprecated , deleted
FROM vps_instances WHERE deleted = FALSE
`,
	)
	if err != nil {
		return nil, errors.Wrap(err, "GetVPSInstances(): ")
	}
	defer rows.Close()

	toReturn := map[string]*VPSInstance{}
	for rows.Next() {
		instance, err := model.rowToVPSInstance(rows)
		if err != nil {
			return nil, errors.Wrap(err, "GetVPSInstances(): ")
		}
		toReturn[instance.GetId()] = instance
	}
	if err := rows.Err(); err != nil {
		return nil, errors.Wrap(err, "GetVPSInstances(): ")
	}
	return toReturn, nil
}
// CreateVPSInstance inserts a row for the given VPS instance and, when the
// instance has a non-zero TenantId, assigns that tenant to it in a second
// statement.
//
// The tenant is written to the tenant_id column, the same column
// GetVPSInstances reads.
func (model *DBModel) CreateVPSInstance(toCreate *VPSInstance) error {
	_, err := model.DB.Exec(`
INSERT INTO vps_instances (
service_provider , provider_instance_id , created ,
ipv4 , ipv6 , bytes_monthly
)
VALUES ( $ 1 , $ 2 , $ 3 ,
$ 4 , $ 5 , $ 6 )
`, toCreate.ServiceProvider, toCreate.ProviderInstanceId, toCreate.Created,
		toCreate.IPV4, toCreate.IPV6, toCreate.BytesMonthly,
	)
	if err != nil {
		return errors.Wrap(err, "CreateVPSInstance(): ")
	}

	// A TenantId of 0 means "no tenant": leave the column at its default.
	if toCreate.TenantId != 0 {
		_, err := model.DB.Exec(
			` UPDATE vps_instances SET tenant_id = $1 WHERE service_provider = $2 AND provider_instance_id = $3 `,
			toCreate.TenantId, toCreate.ServiceProvider, toCreate.ProviderInstanceId,
		)
		if err != nil {
			return errors.Wrap(err, "CreateVPSInstance(): ")
		}
	}
	return nil
}
// DeleteVPSInstance soft-deletes a VPS instance by setting its deleted flag.
// It returns an error when the statement fails or when no row matched the
// provider / instance id pair.
func (model *DBModel) DeleteVPSInstance(provider, providerInstanceId string) error {
	const markDeleted = ` UPDATE vps_instances SET deleted = TRUE WHERE service_provider = $1 AND provider_instance_id = $2 `

	result, err := model.DB.Exec(markDeleted, provider, providerInstanceId)
	if err != nil {
		return errors.Wrap(err, "DeleteVPSInstance(): ")
	}

	affected, err := result.RowsAffected()
	switch {
	case err != nil:
		return errors.Wrap(err, "DeleteVPSInstance(): ")
	case affected == 0:
		return errors.Errorf("DeleteVPSInstance(): '%s-%s' was not found", provider, providerInstanceId)
	}
	return nil
}
// GetTenants loads every tenant together with its tunnel settings, keyed by
// tenant id.
//
// A tenant's AuthorizedDomains consist of its external domains whose
// last_verified timestamp is newer than now - (DomainVerificationPollingInterval
// + 1 minute), followed by "<subdomain>.<freeSubdomainDomain>" when the tenant
// has a subdomain. Fields that are not selected here (Email, APITokens,
// ExternalDomains, ...) are left at their zero values.
//
// An error is returned when a query fails, a row cannot be decoded, or
// iterating over either result set ends with an error.
func (model *DBModel) GetTenants() (map[int]*TenantInfo, error) {
	domainRows, err := model.DB.Query(` SELECT tenant_id, domain_name, last_verified FROM external_domains `)
	if err != nil {
		return nil, errors.Wrap(err, "GetTenants(): ")
	}
	defer domainRows.Close()

	verificationCutoff := time.Now().UTC().Add(-(DomainVerificationPollingInterval + time.Minute))
	authorizedDomains := map[int][]string{}
	for domainRows.Next() {
		var tenantId int
		var domainName string
		var lastVerified time.Time
		if err := domainRows.Scan(&tenantId, &domainName, &lastVerified); err != nil {
			return nil, errors.Wrap(err, "GetTenants(): ")
		}
		if lastVerified.After(verificationCutoff) {
			authorizedDomains[tenantId] = append(authorizedDomains[tenantId], domainName)
		}
	}
	if err := domainRows.Err(); err != nil {
		return nil, errors.Wrap(err, "GetTenants(): ")
	}

	tenantRows, err := model.DB.Query(` SELECT id, created, subdomain, service_limit_cents, port_start, port_end, port_bucket FROM tenants `)
	if err != nil {
		return nil, errors.Wrap(err, "GetTenants(): ")
	}
	defer tenantRows.Close()

	toReturn := map[int]*TenantInfo{}
	for tenantRows.Next() {
		var tenantId int
		var tenantCreated time.Time
		var subdomain *string // nil when the column is NULL
		var serviceLimitCents int
		var portStart int
		var portEnd int
		var portBucket int
		err := tenantRows.Scan(&tenantId, &tenantCreated, &subdomain, &serviceLimitCents, &portStart, &portEnd, &portBucket)
		if err != nil {
			return nil, errors.Wrap(err, "GetTenants(): ")
		}

		// Always hand out a non-nil slice, even for tenants without domains.
		thisTenantDomains := authorizedDomains[tenantId]
		if thisTenantDomains == nil {
			thisTenantDomains = []string{}
		}

		subdomainString := ""
		if subdomain != nil {
			subdomainString = *subdomain
			thisTenantDomains = append(thisTenantDomains, fmt.Sprintf("%s.%s", subdomainString, freeSubdomainDomain))
		}

		toReturn[tenantId] = &TenantInfo{
			Id:                tenantId,
			Created:           tenantCreated,
			Subdomain:         subdomainString,
			ServiceLimitCents: serviceLimitCents,
			PortBucket:        portBucket,
			TunnelSettings: &TunnelSettings{
				PortStart:         portStart,
				PortEnd:           portEnd,
				AuthorizedDomains: thisTenantDomains,
			},
		}
	}
	if err := tenantRows.Err(); err != nil {
		return nil, errors.Wrap(err, "GetTenants(): ")
	}
	return toReturn, nil
}
// GetTenant loads a single tenant with its external domains, API tokens and
// tunnel settings.
//
// ExternalDomains lists every external domain of the tenant with IsValid set
// when its last_verified timestamp is newer than
// now - (DomainVerificationPollingInterval + 1 minute). AuthorizedDomains
// contains only the valid ones, followed by "<subdomain>.<freeSubdomainDomain>"
// when the tenant has a subdomain.
//
// An error is returned when a query fails, a row cannot be decoded, iterating
// over a result set ends with an error, or the tenant row does not exist.
func (model *DBModel) GetTenant(tenantId int) (*TenantInfo, error) {
	rows, err := model.DB.Query(
		` SELECT domain_name, last_verified FROM external_domains WHERE tenant_id = $1 `,
		tenantId,
	)
	if err != nil {
		return nil, errors.Wrapf(err, "GetTenant(%d): ", tenantId)
	}
	defer rows.Close()

	verificationCutoff := time.Now().UTC().Add(-(DomainVerificationPollingInterval + time.Minute))
	authorizedDomains := []string{}
	externalDomains := []ExternalDomain{}
	for rows.Next() {
		var domainName string
		var lastVerified time.Time
		if err := rows.Scan(&domainName, &lastVerified); err != nil {
			return nil, errors.Wrapf(err, "GetTenant(%d): ", tenantId)
		}
		verified := lastVerified.After(verificationCutoff)
		if verified {
			authorizedDomains = append(authorizedDomains, domainName)
		}
		externalDomains = append(externalDomains, ExternalDomain{DomainName: domainName, IsValid: verified})
	}
	if err := rows.Err(); err != nil {
		return nil, errors.Wrapf(err, "GetTenant(%d): ", tenantId)
	}

	tokensRows, err := model.DB.Query(
		` SELECT key_name, hashed_token, active, created, last_used FROM api_tokens WHERE tenant_id = $1 `,
		tenantId,
	)
	if err != nil {
		return nil, errors.Wrapf(err, "GetTenant(%d): ", tenantId)
	}
	defer tokensRows.Close()

	apiTokens := []APIToken{}
	for tokensRows.Next() {
		var token APIToken
		err := tokensRows.Scan(&token.Name, &token.HashedToken, &token.Active, &token.Created, &token.LastUsed)
		if err != nil {
			return nil, errors.Wrapf(err, "GetTenant(%d): ", tenantId)
		}
		apiTokens = append(apiTokens, token)
	}
	if err := tokensRows.Err(); err != nil {
		return nil, errors.Wrapf(err, "GetTenant(%d): ", tenantId)
	}

	var created time.Time
	var subdomain *string // nil when the column is NULL
	var email string
	var smsAlarmNumber *string // nil when the column is NULL
	var billingAlarmCents int
	var serviceLimitCents int
	var portStart int
	var portEnd int
	var portBucket int
	err = model.DB.QueryRow(
		` SELECT created , email , subdomain , sms_alarm_number , billing_alarm_cents , service_limit_cents ,
port_start , port_end , port_bucket
FROM tenants WHERE id = $ 1 `,
		tenantId,
	).Scan(
		&created, &email, &subdomain, &smsAlarmNumber, &billingAlarmCents, &serviceLimitCents,
		&portStart, &portEnd, &portBucket,
	)
	if err != nil {
		return nil, errors.Wrapf(err, "GetTenant(%d): ", tenantId)
	}

	subdomainString := ""
	if subdomain != nil {
		subdomainString = *subdomain
		authorizedDomains = append(authorizedDomains, fmt.Sprintf("%s.%s", subdomainString, freeSubdomainDomain))
	}
	smsString := ""
	if smsAlarmNumber != nil {
		smsString = *smsAlarmNumber
	}

	return &TenantInfo{
		Id:                tenantId,
		Email:             email,
		Created:           created,
		SMSAlarmNumber:    smsString,
		Subdomain:         subdomainString,
		BillingAlarmCents: billingAlarmCents,
		ServiceLimitCents: serviceLimitCents,
		PortBucket:        portBucket,
		TunnelSettings: &TunnelSettings{
			PortStart:         portStart,
			PortEnd:           portEnd,
			AuthorizedDomains: authorizedDomains,
		},
		APITokens:       apiTokens,
		ExternalDomains: externalDomains,
	}, nil
}
// AddExternalDomain records externalDomain as belonging to the tenant. Only
// tenant_id and domain_name are written by this statement.
func (model *DBModel) AddExternalDomain(tenantId int, externalDomain string) error {
	const insertDomain = "INSERT INTO external_domains (tenant_id, domain_name) VALUES ($1, $2)"

	if _, err := model.DB.Exec(insertDomain, tenantId, externalDomain); err != nil {
		return errors.Wrap(err, "AddExternalDomain(): ")
	}
	return nil
}
// GetExternalDomains returns one [externalDomain, personalDomain] pair for
// every external domain whose tenant has a subdomain, where personalDomain is
// "<subdomain>.<freeSubdomainDomain>". External domains of tenants without a
// subdomain are skipped.
//
// An error is returned when a query fails, a row cannot be decoded, or
// iterating over either result set ends with an error.
func (model *DBModel) GetExternalDomains() ([][]string, error) {
	rows, err := model.DB.Query(` SELECT id, subdomain FROM tenants `)
	if err != nil {
		return nil, errors.Wrap(err, "GetExternalDomains(): ")
	}
	defer rows.Close()

	personalDomainsByTenant := map[int]string{}
	for rows.Next() {
		var tenantId int
		var subdomain *string // nil when the column is NULL
		if err := rows.Scan(&tenantId, &subdomain); err != nil {
			return nil, errors.Wrap(err, "GetExternalDomains(): ")
		}
		if subdomain != nil {
			personalDomainsByTenant[tenantId] = fmt.Sprintf("%s.%s", *subdomain, freeSubdomainDomain)
		}
	}
	if err := rows.Err(); err != nil {
		return nil, errors.Wrap(err, "GetExternalDomains(): ")
	}

	externalDomainsRows, err := model.DB.Query(` SELECT tenant_id, domain_name FROM external_domains `)
	if err != nil {
		return nil, errors.Wrap(err, "GetExternalDomains(): ")
	}
	defer externalDomainsRows.Close()

	toReturn := [][]string{}
	for externalDomainsRows.Next() {
		var tenantId int
		var externalDomain string
		if err := externalDomainsRows.Scan(&tenantId, &externalDomain); err != nil {
			return nil, errors.Wrap(err, "GetExternalDomains(): ")
		}
		if personalDomain, hasPersonalDomain := personalDomainsByTenant[tenantId]; hasPersonalDomain {
			toReturn = append(toReturn, []string{externalDomain, personalDomain})
		}
	}
	if err := externalDomainsRows.Err(); err != nil {
		return nil, errors.Wrap(err, "GetExternalDomains(): ")
	}
	return toReturn, nil
}
// MarkExternalDomainAsVerified sets last_verified to the database's current
// time for the given external domain. It returns an error when the statement
// fails or when the number of updated rows is not exactly one.
func (model *DBModel) MarkExternalDomainAsVerified(externalDomain string) error {
	const touchDomain = "UPDATE external_domains SET last_verified = NOW() WHERE domain_name = $1"

	result, err := model.DB.Exec(touchDomain, externalDomain)
	if err != nil {
		return errors.Wrap(err, "MarkExternalDomainAsVerified(): ")
	}

	updated, err := result.RowsAffected()
	switch {
	case err != nil:
		return errors.Wrap(err, "MarkExternalDomainAsVerified(): ")
	case updated != 1:
		return errors.Errorf("zero rows were affected by MarkExternalDomainAsVerified('%s')", externalDomain)
	}
	return nil
}
// GetTenantVPSInstanceRows returns every tenant_vps_instance row of the given
// billing year and month. The shadow_config column is decoded from JSON into
// TunnelSettings.
//
// An error is returned when the query fails, a row cannot be decoded (including
// invalid shadow_config JSON), or iterating over the result set ends with an
// error.
func (model *DBModel) GetTenantVPSInstanceRows(billingYear, billingMonth int) ([]*TenantVPSInstance, error) {
	rows, err := model.DB.Query(`
SELECT
tenant_id ,
service_provider ,
provider_instance_id ,
shadow_config ,
bytes ,
active ,
deactivated_at
FROM tenant_vps_instance
WHERE billing_year = $ 1 AND billing_month = $ 2
`, billingYear, billingMonth,
	)
	if err != nil {
		return nil, errors.Wrap(err, "GetTenantVPSInstanceRows(): ")
	}
	defer rows.Close()

	toReturn := []*TenantVPSInstance{}
	for rows.Next() {
		var tenantId int
		var serviceProvider string
		var serviceProviderInstanceId string
		var shadowConfigString string
		var bytes int64
		var active bool
		var deactivatedAt *time.Time // nil when the column is NULL
		err := rows.Scan(&tenantId, &serviceProvider, &serviceProviderInstanceId, &shadowConfigString, &bytes, &active, &deactivatedAt)
		if err != nil {
			return nil, errors.Wrap(err, "GetTenantVPSInstanceRows(): ")
		}

		var shadowConfig TunnelSettings
		if err := json.Unmarshal([]byte(shadowConfigString), &shadowConfig); err != nil {
			return nil, errors.Wrap(err, "GetTenantVPSInstanceRows(): ")
		}

		toReturn = append(toReturn, &TenantVPSInstance{
			TenantId:           tenantId,
			ServiceProvider:    serviceProvider,
			ProviderInstanceId: serviceProviderInstanceId,
			ShadowConfig:       &shadowConfig,
			Bytes:              bytes,
			Active:             active,
			DeactivatedAt:      deactivatedAt,
		})
	}
	if err := rows.Err(); err != nil {
		return nil, errors.Wrap(err, "GetTenantVPSInstanceRows(): ")
	}
	return toReturn, nil
}
// RecordTenantsUsage inserts one tenant_metrics_bandwidth row per entry of
// usage (tenant id -> bytes). The inserts are dispatched through doInParallel;
// every failed insert is collected and reported in a single combined error.
func (model *DBModel) RecordTenantsUsage(usage map[int]int64) error {
	// insertUsage binds one tenant's values so each task works on its own copy.
	insertUsage := func(tenantId int, bytez int64) func() taskResult {
		return func() taskResult {
			_, err := model.DB.Exec(` INSERT INTO tenant_metrics_bandwidth (tenant_id, bytes) VALUES ($1, $2) `, tenantId, bytez)
			return taskResult{Name: strconv.Itoa(tenantId), Err: err}
		}
	}

	tasks := make([]func() taskResult, 0, len(usage))
	for tenantId, bytez := range usage {
		tasks = append(tasks, insertUsage(tenantId, bytez))
	}

	failures := []string{}
	for key, outcome := range doInParallel(false, tasks...) {
		if outcome.Err == nil {
			continue
		}
		failures = append(failures, fmt.Sprintf("tenant %s: %+v", key, outcome.Err))
	}
	if len(failures) == 0 {
		return nil
	}
	return errors.Errorf("RecordTenantsUsage(): \n%s\n", strings.Join(failures, "\n"))
}
// RecordVPSUsage adds each tenant's combined inbound and outbound byte count
// from usage to the matching tenant_vps_instance row of the given instance and
// billing period.
//
// Tenant keys in usage must be decimal tenant ids; a key that is not causes an
// immediate error before any update is dispatched. The updates run through
// doInParallel, and every failure (including an update that did not hit
// exactly one row) is collected into a single combined error.
func (model *DBModel) RecordVPSUsage(instance *VPSInstance, usage ThresholdMetrics, billingYear int, billingMonth int) error {
	totals := map[string]int64{}
	for tenant, inbound := range usage.InboundByTenant {
		totals[tenant] += inbound
	}
	for tenant, outbound := range usage.OutboundByTenant {
		totals[tenant] += outbound
	}

	// addUsage binds one tenant's values so each task works on its own copy.
	addUsage := func(tenantId int, bytez int64) func() taskResult {
		return func() taskResult {
			name := strconv.Itoa(tenantId)
			result, err := model.DB.Exec(
				` UPDATE tenant_vps_instance SET bytes = bytes + $ 1
WHERE service_provider = $ 2 AND provider_instance_id = $ 3 AND tenant_id = $ 4 AND billing_year = $ 5 AND billing_month = $ 6 ; `,
				bytez, instance.ServiceProvider, instance.ProviderInstanceId, tenantId, billingYear, billingMonth,
			)
			if err != nil {
				return taskResult{Name: name, Err: err}
			}
			updated, err := result.RowsAffected()
			if err != nil {
				return taskResult{Name: name, Err: err}
			}
			if updated != 1 {
				return taskResult{
					Name: name,
					Err: errors.Errorf(
						"tenant_vps_instance row not found for vps '%s' tenant '%d'",
						instance.GetId(), tenantId,
					),
				}
			}
			return taskResult{Name: name}
		}
	}

	tasks := []func() taskResult{}
	for tenantKey, bytez := range totals {
		tenantId, err := strconv.Atoi(tenantKey)
		if err != nil {
			return errors.Wrap(err, "RecordVPSUsage(): ")
		}
		tasks = append(tasks, addUsage(tenantId, bytez))
	}

	failures := []string{}
	for key, outcome := range doInParallel(false, tasks...) {
		if outcome.Err == nil {
			continue
		}
		failures = append(failures, fmt.Sprintf("tenant %s: %+v", key, outcome.Err))
	}
	if len(failures) == 0 {
		return nil
	}
	return errors.Errorf("RecordVPSUsage(): \n%s\n", strings.Join(failures, "\n"))
}
// GetTenantUsageTotal returns the sum of the bytes column over all of the
// tenant's tenant_vps_instance rows in the given billing year and month. The
// result is 0 when there are no such rows.
//
// An error is returned when the query fails, a row cannot be decoded, or
// iterating over the result set ends with an error, so a partial sum is never
// reported as the total.
func (model *DBModel) GetTenantUsageTotal(tenantId int, billingYear, billingMonth int) (int64, error) {
	rows, err := model.DB.Query(`
SELECT bytes FROM tenant_vps_instance WHERE tenant_id = $ 1 AND billing_year = $ 2 AND billing_month = $ 3
`, tenantId, billingYear, billingMonth,
	)
	if err != nil {
		return 0, errors.Wrap(err, "GetTenantUsageTotal(): ")
	}
	defer rows.Close()

	var monthlyBytes int64
	for rows.Next() {
		var bytes int64
		if err := rows.Scan(&bytes); err != nil {
			return 0, errors.Wrap(err, "GetTenantUsageTotal(): ")
		}
		monthlyBytes += bytes
	}
	if err := rows.Err(); err != nil {
		return 0, errors.Wrap(err, "GetTenantUsageTotal(): ")
	}
	return monthlyBytes, nil
}
// GetTenantUsageMetrics returns the tenant's bandwidth measurements taken
// strictly between start and end, keyed by measurement time.
//
// An error is returned when the query fails, a row cannot be decoded, or
// iterating over the result set ends with an error.
func (model *DBModel) GetTenantUsageMetrics(tenantId int, start, end time.Time) (map[time.Time]int64, error) {
	rows, err := model.DB.Query(`
SELECT measured , bytes FROM tenant_metrics_bandwidth WHERE tenant_id = $ 1 AND measured > $ 2 AND measured < $ 3
`, tenantId, start, end,
	)
	if err != nil {
		return nil, errors.Wrap(err, "GetTenantUsageMetrics(): ")
	}
	defer rows.Close()

	toReturn := map[time.Time]int64{}
	for rows.Next() {
		var measured time.Time
		var bytes int64
		if err := rows.Scan(&measured, &bytes); err != nil {
			return nil, errors.Wrap(err, "GetTenantUsageMetrics(): ")
		}
		toReturn[measured] = bytes
	}
	if err := rows.Err(); err != nil {
		return nil, errors.Wrap(err, "GetTenantUsageMetrics(): ")
	}
	return toReturn, nil
}
// SaveInstanceConfiguration persists the tunnel configuration of one VPS
// instance for a billing period.
//
// For every tenant in config it upserts the tenant_vps_instance row, storing
// the JSON-encoded settings as shadow_config and marking the row active (new
// rows start with bytes = 0). Afterwards every row of this instance and
// billing period whose tenant is not in config is marked inactive with
// deactivated_at set to the database's current time. With an empty config
// that means all of the instance's rows for the period are deactivated.
func (model *DBModel) SaveInstanceConfiguration(
	billingYear int,
	billingMonth int,
	instance *VPSInstance,
	config map[int]*TunnelSettings,
) error {
	// first we set shadow_config & active=true for all tenants mentioned in the config
	tenantIdsStrings := make([]string, 0, len(config))
	for tenantId, tunnelSettings := range config {
		tenantIdsStrings = append(tenantIdsStrings, strconv.Itoa(tenantId))

		shadowConfigBytes, err := json.Marshal(tunnelSettings)
		if err != nil {
			return errors.Wrapf(err, "cant serialize shadow config for tenant %d on %s", tenantId, instance.GetId())
		}
		shadowConfig := string(shadowConfigBytes)

		_, err = model.DB.Exec(`
INSERT INTO tenant_vps_instance (
billing_year , billing_month , tenant_id , service_provider , provider_instance_id ,
shadow_config , bytes , active
)
VALUES ( $ 1 , $ 2 , $ 3 , $ 4 , $ 5 ,
$ 6 , 0 , $ 7 )
ON CONFLICT ON CONSTRAINT pk_tenant_vps_instance
DO
UPDATE SET shadow_config = $ 6 , active = $ 7 ;
`,
			billingYear, billingMonth, tenantId, instance.ServiceProvider, instance.ProviderInstanceId,
			shadowConfig, true,
		)
		if err != nil {
			return errors.Wrap(err, "SaveInstanceConfiguration(): ")
		}
	}

	// next, we disable all existing tenants for this instance which are not mentioned in the config
	deactivateQuery := `
UPDATE tenant_vps_instance SET active = FALSE , deactivated_at = NOW ( )
WHERE billing_year = $ 1 AND billing_month = $ 2 AND service_provider = $ 3 AND provider_instance_id = $ 4
`
	// "NOT IN ( )" with an empty list is not valid SQL, so the exclusion is
	// only added when there is at least one tenant to keep. The ids are
	// produced by strconv.Itoa, so splicing them into the statement is safe.
	if len(tenantIdsStrings) > 0 {
		deactivateQuery += fmt.Sprintf("AND tenant_id NOT IN ( %s )\n", strings.Join(tenantIdsStrings, ", "))
	}
	_, err := model.DB.Exec(
		deactivateQuery,
		billingYear, billingMonth, instance.ServiceProvider, instance.ProviderInstanceId,
	)
	return errors.Wrap(err, "SaveInstanceConfiguration(): ")
}
// PollScheduledTask reports whether the named task should run now.
//
// When the task has no scheduled_tasks row yet, one is inserted with
// last_succeeded set to just after the Unix epoch and true is returned. When
// the row exists, true is returned (and last_started is set to the current
// UTC time) if more than `every` has passed since last_succeeded; otherwise
// false is returned and nothing is changed.
//
// A failure while reading the existing row is returned as an error instead of
// being mistaken for "the task has never been registered".
func (model *DBModel) PollScheduledTask(name string, every time.Duration) (bool, error) {
	rows, err := model.DB.Query(` SELECT last_started, last_succeeded FROM scheduled_tasks WHERE name = $1 `, name)
	if err != nil {
		return false, errors.Wrap(err, "PollScheduledTask(): ")
	}
	defer rows.Close()

	if !rows.Next() {
		if err := rows.Err(); err != nil {
			return false, errors.Wrap(err, "PollScheduledTask(): ")
		}
		// First poll for this task: register it as "never succeeded".
		unixEpoch := time.Date(1970, 1, 1, 0, 0, 0, 1, time.UTC)
		_, err := model.DB.Exec("INSERT INTO scheduled_tasks (name, last_succeeded) VALUES ($1, $2)", name, unixEpoch)
		if err != nil {
			return false, errors.Wrap(err, "PollScheduledTask(): ")
		}
		return true, nil
	}

	var lastStarted time.Time
	var lastSucceeded time.Time
	if err := rows.Scan(&lastStarted, &lastSucceeded); err != nil {
		return false, errors.Wrap(err, "PollScheduledTask(): ")
	}
	if time.Since(lastSucceeded) <= every {
		return false, nil
	}

	_, err = model.DB.Exec("UPDATE scheduled_tasks SET last_started = $1 WHERE name = $2", time.Now().UTC(), name)
	if err != nil {
		return false, errors.Wrap(err, "PollScheduledTask(): ")
	}
	return true, nil
}
// ScheduledTaskCompleted records that the named task just finished
// successfully by setting its last_succeeded column to the current UTC time.
func (model *DBModel) ScheduledTaskCompleted(name string) error {
	_, err := model.DB.Exec("UPDATE scheduled_tasks SET last_succeeded = $1 WHERE name = $2", time.Now().UTC(), name)
	if err != nil {
		return errors.Wrap(err, "ScheduledTaskCompleted(): ")
	}
	return nil
}
// PutKeyPair stores the key and certificate bytes for (caName, name) in
// pki_key_pairs, replacing the stored bytes when the pair already exists.
func (model *DBModel) PutKeyPair(caName, name string, key, cert []byte) error {
	const upsertKeyPair = `
INSERT INTO pki_key_pairs ( ca_name , name , key_bytes , cert_bytes )
VALUES ( $ 1 , $ 2 , $ 3 , $ 4 )
ON CONFLICT ON CONSTRAINT pk_pki_key_pairs
DO
UPDATE SET key_bytes = $ 3 , cert_bytes = $ 4 ;
`
	_, err := model.DB.Exec(upsertKeyPair, caName, name, key, cert)
	return errors.Wrap(err, "PutKeyPair(): ")
}
// GetServerKeyPair returns the stored key and certificate bytes for
// (caName, name).
//
// It returns (nil, nil, nil) when no such key pair is stored, and an error
// when the query fails, the row cannot be decoded, or reading the result set
// ends with an error.
func (model *DBModel) GetServerKeyPair(caName, name string) ([]byte, []byte, error) {
	rows, err := model.DB.Query(`
SELECT key_bytes , cert_bytes FROM pki_key_pairs
WHERE ca_name = $ 1 AND name = $ 2
`,
		caName, name,
	)
	if err != nil {
		return nil, nil, errors.Wrap(err, "GetServerKeyPair(): ")
	}
	defer rows.Close()

	// Only the first matching row is used.
	if rows.Next() {
		var key, cert []byte
		if err := rows.Scan(&key, &cert); err != nil {
			return nil, nil, errors.Wrap(err, "GetServerKeyPair(): ")
		}
		return key, cert, nil
	}
	if err := rows.Err(); err != nil {
		return nil, nil, errors.Wrap(err, "GetServerKeyPair(): ")
	}
	return nil, nil, nil
}
// rowToVPSInstance decodes the current row of a vps_instances query into a
// VPSInstance. The row must provide, in order: service_provider,
// provider_instance_id, tenant_id, ipv4, ipv6, bytes_monthly, created,
// deprecated, deleted. A NULL tenant_id becomes TenantId 0.
func (model *DBModel) rowToVPSInstance(row RowScanner) (*VPSInstance, error) {
	var (
		provider     string
		instanceId   string
		tenant       *int // nil when the column is NULL
		ipv4, ipv6   string
		bytesMonthly int64
		created      time.Time
		deprecated   bool
		deleted      bool
	)
	if err := row.Scan(&provider, &instanceId, &tenant, &ipv4, &ipv6, &bytesMonthly, &created, &deprecated, &deleted); err != nil {
		return nil, errors.Wrap(err, "rowToVPSInstance(): ")
	}

	instance := &VPSInstance{
		ServiceProvider:    provider,
		ProviderInstanceId: instanceId,
		IPV4:               ipv4,
		IPV6:               ipv6,
		BytesMonthly:       bytesMonthly,
		Created:            created,
		Deprecated:         deprecated,
		Deleted:            deleted,
	}
	if tenant != nil {
		instance.TenantId = *tenant
	}
	return instance, nil
}
// domainNamesToString returns the given names sorted and joined with ",",
// giving an order-independent representation of the list.
//
// The input slice is left untouched: sorting happens on a private copy, since
// sort.Strings reorders its argument in place.
func domainNamesToString(slice []string) string {
	sorted := append([]string(nil), slice...)
	sort.Strings(sorted)
	return strings.Join(sorted, ",")
}