🌱🏠 a cloud service to enable your own server (owned by you and running on your computer) to be accessible on the internet in seconds, no credit card required https://greenhouse.server.garden/
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 

1018 lines
29 KiB

package main
import (
"database/sql"
"encoding/json"
"fmt"
"io/ioutil"
"log"
"path/filepath"
"regexp"
"sort"
"strconv"
"strings"
"time"
errors "git.sequentialread.com/forest/pkg-errors"
_ "github.com/lib/pq"
)
// RowScanner is the minimal Scan-capable interface that rowToVPSInstance
// needs. GetVPSInstances passes the cursor returned by DB.Query through it.
type RowScanner interface {
// Scan copies the columns of the current row into the given destinations.
Scan(src ...interface{}) error
}
// DBModel bundles the database handle with the settings used by the query
// methods in this file. It is constructed by initDatabase.
type DBModel struct {
// DB is the handle every query in this file is executed against.
DB *sql.DB
// subdomainRegex is what SetFreeSubdomain validates candidate subdomains with.
subdomainRegex *regexp.Regexp
// PortRangeSize is the number of ports GetNextReservedPorts hands out per call.
PortRangeSize int
// MinListeningPort is the port the counter restarts from when it moves to a new bucket.
MinListeningPort int
// MaxListeningPort is the bound that makes GetNextReservedPorts move to a new bucket.
MaxListeningPort int
}
// TunnelSettings holds a tenant's reserved port range and domain names.
// GetTenant/GetTenants build it from the tenants and external_domains tables,
// and SaveInstanceConfiguration stores it as JSON in the shadow_config column.
type TunnelSettings struct {
// PortStart is read from tenants.port_start.
PortStart int
// PortEnd is read from tenants.port_end.
PortEnd int
// AuthorizedDomains lists the recently verified external domains plus the
// tenant's free subdomain, when one is set.
AuthorizedDomains []string
}
// DeepEquals reports whether both settings carry the same port range and the
// same authorized domains. The domain lists are compared through
// domainNamesToString, so their order does not matter.
func (settings *TunnelSettings) DeepEquals(other *TunnelSettings) bool {
	ownDomains := domainNamesToString(settings.AuthorizedDomains)
	otherDomains := domainNamesToString(other.AuthorizedDomains)
	samePorts := settings.PortStart == other.PortStart && settings.PortEnd == other.PortEnd
	return ownDomains == otherDomains && samePorts
}
// TenantInfo is the in-memory view of a tenant assembled by GetTenant and
// GetTenants. GetTenants fills only Id, Created, Subdomain, ServiceLimitCents,
// PortBucket and TunnelSettings; GetTenant additionally fills Email,
// SMSAlarmNumber, BillingAlarmCents and APITokens.
type TenantInfo struct {
// Id is tenants.id.
Id int
// Created is tenants.created.
Created time.Time
// Email is tenants.email.
Email string
// Subdomain is tenants.subdomain, or "" when the column is NULL.
Subdomain string
// DedicatedVPSCount is not populated by the code visible in this file.
DedicatedVPSCount int
// Bytes is not populated by the code visible in this file.
Bytes int64
// SMSAlarmNumber is tenants.sms_alarm_number, or "" when the column is NULL.
SMSAlarmNumber string
// ServiceLimitCents is tenants.service_limit_cents.
ServiceLimitCents int
// BillingAlarmCents is tenants.billing_alarm_cents.
BillingAlarmCents int
// PortBucket is tenants.port_bucket.
PortBucket int
// TunnelSettings carries the tenant's port range and authorized domains.
TunnelSettings *TunnelSettings
// Deactivated is not populated by the code visible in this file.
Deactivated bool
// APITokens are the tenant's rows from the api_tokens table.
APITokens []APIToken
}
// TenantVPSInstance mirrors one row of the tenant_vps_instance table as read
// by GetTenantVPSInstanceRows.
type TenantVPSInstance struct {
// TenantId is the tenant_id column.
TenantId int
// ServiceProvider is the service_provider column.
ServiceProvider string
// ProviderInstanceId is the provider_instance_id column.
ProviderInstanceId string
// ShadowConfig is the shadow_config column decoded from JSON.
ShadowConfig *TunnelSettings
// Bytes is the bytes column, which RecordVPSUsage increments.
Bytes int64
// Active is the active column.
Active bool
// DeactivatedAt is the deactivated_at column; nil when the column is NULL.
DeactivatedAt *time.Time
}
// APIToken mirrors one row of the api_tokens table as read by GetTenant.
type APIToken struct {
// Name is the key_name column.
Name string
// Active is the active column; GetUserByAPIToken only accepts active tokens.
Active bool
// HashedToken is the hashed_token column.
HashedToken string
// Created is the created column.
Created time.Time
// LastUsed is the last_used column.
LastUsed time.Time
}
// DomainVerificationPollingInterval is the window used by GetTenant and
// GetTenants when filtering external domains: a domain counts as authorized
// only if its last_verified timestamp is newer than this interval plus one
// minute before now.
const DomainVerificationPollingInterval = time.Hour
// GetVPSInstanceId returns the instance identifier in the form
// "<ServiceProvider>-<ProviderInstanceId>".
func (i *TenantVPSInstance) GetVPSInstanceId() string {
	return i.ServiceProvider + "-" + i.ProviderInstanceId
}
// initDatabase opens the database described by config, creates the
// schema_version bookkeeping table if it does not exist yet, and then runs the
// migration scripts from the schema_versions directory one version at a time
// until the database reports the schema version this build expects.
//
// Every failure is fatal (log.Fatal / log.Fatalf), so the function only
// returns on success.
//
// Migration scripts are located by file name: "<NN>_<direction>_*.sql", where
// NN is the zero-padded version and direction is "up" or "down". Each script
// is wrapped in a transaction together with the schema_version update so the
// two are meant to be committed as a unit.
func initDatabase(config *Config) *DBModel {
	desiredSchemaVersion := 2

	db, err := sql.Open(config.DatabaseType, config.DatabaseConnectionString)
	if err != nil {
		log.Fatal(err)
	}
	if err := db.Ping(); err != nil {
		log.Fatalf("failed to open database connection: %+v", err)
	}

	// Bootstrap: a database without a schema_version table is treated as
	// being at version 1.
	var tableName string
	err = db.QueryRow(
		"SELECT table_name FROM information_schema.tables WHERE table_schema = $1 AND table_name = $2",
		config.DatabaseSchema, "schema_version",
	).Scan(&tableName)
	if err == sql.ErrNoRows {
		_, err := db.Exec(`
CREATE TABLE schema_version (
version INT PRIMARY KEY NOT NULL
);
INSERT INTO schema_version(version) VALUES (1);
`)
		if err != nil {
			log.Fatalf("failed to create schema_version table: %+v", err)
		}
	} else if err != nil {
		log.Fatalf("failed to check whether or not schema_version table exists: %+v", err)
	}

	readSchemaVersionFromDatabase := func() int {
		var version int
		if err := db.QueryRow("SELECT version FROM schema_version").Scan(&version); err != nil {
			log.Fatalf("failed to select currentSchemaVersion: %+v", err)
		}
		return version
	}
	currentSchemaVersion := readSchemaVersionFromDatabase()

	files, err := ioutil.ReadDir("schema_versions")
	if err != nil {
		log.Fatalf("failed to list schema_versions: %+v", err)
	}

	// getMigrationScript finds the script for (scriptVersion, direction) and
	// wraps it in a transaction that also records resultingVersion. The two
	// version numbers differ for "down" scripts: the script that undoes
	// version N must leave the database at version N-1.
	getMigrationScript := func(scriptVersion int, direction string, resultingVersion int) (filename string, content string) {
		prefix := fmt.Sprintf("%02d_%s_", scriptVersion, direction)
		for _, file := range files {
			if !file.IsDir() && strings.HasPrefix(file.Name(), prefix) && strings.HasSuffix(file.Name(), ".sql") {
				filename = filepath.Join("schema_versions", file.Name())
				contentsBytes, err := ioutil.ReadFile(filename)
				if err != nil {
					log.Fatalf("failed to read file '%s': %+v", filename, err)
				}
				content = fmt.Sprintf(`
BEGIN TRANSACTION ISOLATION LEVEL SERIALIZABLE;
%s
UPDATE schema_version SET version = %d;
COMMIT TRANSACTION;
`, string(contentsBytes), resultingVersion)
				break
			}
		}
		if content == "" {
			log.Fatalf("didn't find any files in schema_versions matching %s*.sql", prefix)
		}
		return filename, content
	}

	for currentSchemaVersion != desiredSchemaVersion {
		log.Printf(
			"currentSchemaVersion (%d) != desiredSchemaVersion (%d), running database schema migration(s)\n",
			currentSchemaVersion, desiredSchemaVersion,
		)

		var expectedSchemaVersion int
		var filename, content string
		if currentSchemaVersion < desiredSchemaVersion {
			// Upgrading: run the "up" script of the next version.
			expectedSchemaVersion = currentSchemaVersion + 1
			filename, content = getMigrationScript(expectedSchemaVersion, "up", expectedSchemaVersion)
		} else {
			// Downgrading: run the "down" script of the current version,
			// which leaves the database one version lower.
			expectedSchemaVersion = currentSchemaVersion - 1
			filename, content = getMigrationScript(currentSchemaVersion, "down", expectedSchemaVersion)
		}
		if _, err := db.Exec(content); err != nil {
			log.Fatalf("failed to execute database migration script %s: %+v", filename, err)
		}

		// Re-read the version so a script that did not take effect is caught
		// instead of looping forever.
		actualSchemaVersion := readSchemaVersionFromDatabase()
		if expectedSchemaVersion != actualSchemaVersion {
			log.Fatalf(
				"expecting database schema version (%d) to be %d after running database migration script %s",
				actualSchemaVersion, expectedSchemaVersion, filename,
			)
		}
		currentSchemaVersion = actualSchemaVersion
	}

	return &DBModel{
		DB:             db,
		subdomainRegex: regexp.MustCompile("^[a-z0-9]([a-z0-9-_]*[a-z0-9]+)?$"),
		// TODO make these configurable?
		PortRangeSize:    20,
		MinListeningPort: 10000,
		MaxListeningPort: 30000,
	}
}
// ---------------- DBModel methods ----------------
// Register creates a tenant row for the lower-cased email address and returns
// the new tenant id.
//
// It returns an error when a tenant with that email already exists, when the
// existence check itself fails, or when the insert fails. A failed existence
// check is reported as such rather than being mistaken for "already
// registered".
func (model *DBModel) Register(email, hashedPassword string) (int, error) {
	normalizedEmail := strings.ToLower(email)

	var existingAccount string
	err := model.DB.QueryRow("SELECT id FROM tenants WHERE email = $1", normalizedEmail).Scan(&existingAccount)
	if err == nil {
		return 0, fmt.Errorf("email address '%s' is already associated with an account", normalizedEmail)
	}
	if err != sql.ErrNoRows {
		return 0, errors.Wrap(err, "Register(): could not check for an existing account")
	}

	var inserted int
	err = model.DB.QueryRow(
		"INSERT INTO tenants (email, hashed_password) VALUES ($1, $2) RETURNING id",
		normalizedEmail, hashedPassword,
	).Scan(&inserted)
	if err != nil {
		return 0, errors.Wrap(err, "Register(): could not insert row into tenants table")
	}
	return inserted, nil
}
// CreateEmailVerificationToken stores a verification token for the tenant
// together with its expiry, which is converted to UTC before being written.
func (model *DBModel) CreateEmailVerificationToken(token string, tenantId int, expires time.Time) error {
	expiresUTC := expires.UTC()
	if _, insertErr := model.DB.Exec(
		"INSERT INTO email_verification_tokens (token, tenant_id, expires) VALUES ($1, $2, $3)",
		token, tenantId, expiresUTC,
	); insertErr != nil {
		return errors.Wrap(insertErr, "CreateEmailVerificationToken(): could not insert row into email_verification_tokens table")
	}
	return nil
}
// VerifyEmail marks the tenant's email as verified if the given token exists
// for that tenant and has not expired.
//
// The caller always receives a generic error message; the underlying database
// errors are only written to the log. The token is deleted after the tenant
// has been marked verified, so a failed update does not consume the token.
// Failing to delete the used token is logged but does not fail verification.
func (model *DBModel) VerifyEmail(token string, tenantId int) error {
	var expires time.Time
	err := model.DB.QueryRow(
		"SELECT expires FROM email_verification_tokens WHERE token = $1 AND tenant_id = $2",
		token, tenantId,
	).Scan(&expires)
	if err != nil && err != sql.ErrNoRows {
		log.Printf("VerifyEmail(): query error %+v", err)
	}
	if err != nil || time.Now().After(expires) {
		return errors.New("email verification token was invalid or expired")
	}

	if _, err := model.DB.Exec("UPDATE tenants SET email_verified = TRUE WHERE id = $1", tenantId); err != nil {
		log.Printf("VerifyEmail(): could not mark tenant %d as verified: %+v", tenantId, err)
		return errors.New("internal error occurred during email verification")
	}

	// Best effort: a leftover token is harmless once the tenant is verified.
	if _, err := model.DB.Exec("DELETE FROM email_verification_tokens WHERE token = $1", token); err != nil {
		log.Printf("VerifyEmail(): could not delete used verification token: %+v", err)
	}
	return nil
}
// GetLoginInfo looks up a tenant by lower-cased email and returns its id,
// hashed password and email-verified flag. When the lookup fails or yields no
// tenant, the hashed password is "" and the flag is false.
func (model *DBModel) GetLoginInfo(email string) (int, string, bool) {
	var (
		tenantId       int
		hashedPassword string
		emailVerified  bool
	)
	row := model.DB.QueryRow(
		"SELECT id, hashed_password, email_verified FROM tenants WHERE email = $1",
		strings.ToLower(email),
	)
	if scanErr := row.Scan(&tenantId, &hashedPassword, &emailVerified); scanErr != nil || tenantId == 0 {
		return tenantId, "", false
	}
	return tenantId, hashedPassword, emailVerified
}
// GetSession loads the session stored under the given cookie id, joined with
// the owning tenant. When cameFromLaxCookie is true the tenant must also have
// lax_cookie set. It returns (nil, nil) when no matching session exists.
func (model *DBModel) GetSession(id string, cameFromLaxCookie bool) (*Session, error) {
	query := `
SELECT session_cookies.tenant_id, tenants.email, tenants.email_verified, session_cookies.expires
FROM session_cookies JOIN tenants on session_cookies.tenant_id = tenants.id
WHERE session_cookies.id = $1 `
	if cameFromLaxCookie {
		query += "and tenants.lax_cookie = TRUE"
	}

	var (
		tenantId      int
		email         string
		emailVerified bool
		expires       time.Time
	)
	err := model.DB.QueryRow(query, id).Scan(&tenantId, &email, &emailVerified, &expires)
	switch {
	case err == sql.ErrNoRows:
		return nil, nil
	case err != nil:
		return nil, errors.Wrapf(err, "GetSession(id=%s, cameFromLaxCookie=%t): ", id, cameFromLaxCookie)
	}

	session := &Session{
		TenantId:      tenantId,
		Email:         email,
		EmailVerified: emailVerified,
		Expires:       expires.UTC(),
	}
	return session, nil
}
// SetSession records the tenant's lax_cookie preference, removes any existing
// session cookies for the tenant and stores the new session under id. The
// three statements run one after another and stop at the first failure.
func (model *DBModel) SetSession(id string, session *Session) error {
	statements := []struct {
		query string
		args  []interface{}
	}{
		{
			"UPDATE tenants SET lax_cookie = $1 WHERE id = $2",
			[]interface{}{session.LaxCookie, session.TenantId},
		},
		{
			"DELETE FROM session_cookies WHERE tenant_id = $1",
			[]interface{}{session.TenantId},
		},
		{
			"INSERT INTO session_cookies (id, tenant_id, expires) VALUES ($1, $2, $3)",
			[]interface{}{id, session.TenantId, session.Expires.UTC()},
		},
	}
	for _, statement := range statements {
		if _, execErr := model.DB.Exec(statement.query, statement.args...); execErr != nil {
			return errors.Wrap(execErr, "SetSession(): ")
		}
	}
	return nil
}
// LogoutTenant deletes every session cookie belonging to the tenant and
// returns the database error, if any, unwrapped.
func (model *DBModel) LogoutTenant(tenantId int) error {
	if _, deleteErr := model.DB.Exec("DELETE FROM session_cookies WHERE tenant_id = $1", tenantId); deleteErr != nil {
		return deleteErr
	}
	return nil
}
// GetUserByAPIToken resolves an active API token to a Session describing the
// tenant that owns it. Tokens shorter than 8 characters are rejected with an
// error; an unknown or inactive token yields (nil, nil).
func (model *DBModel) GetUserByAPIToken(hashedApiToken string) (*Session, error) {
	if len(hashedApiToken) < 8 {
		return nil, errors.New("The given hashedApiToken token was too short")
	}

	var tenantId int
	tokenRow := model.DB.QueryRow(
		`SELECT tenant_id FROM api_tokens WHERE hashed_token = $1 AND active = TRUE`,
		hashedApiToken,
	)
	switch lookupErr := tokenRow.Scan(&tenantId); {
	case lookupErr == sql.ErrNoRows:
		return nil, nil
	case lookupErr != nil:
		return nil, errors.Wrapf(lookupErr, "GetUserByAPIToken(hashedApiToken=%s): ", hashedApiToken)
	}

	var (
		email         string
		emailVerified bool
	)
	tenantRow := model.DB.QueryRow(`SELECT email, email_verified FROM tenants WHERE id = $1`, tenantId)
	if tenantErr := tenantRow.Scan(&email, &emailVerified); tenantErr != nil {
		return nil, errors.Wrapf(tenantErr, "GetUserByAPIToken(hashedApiToken=%s): ", hashedApiToken)
	}

	session := &Session{
		TenantId:      tenantId,
		Email:         email,
		EmailVerified: emailVerified,
	}
	return session, nil
}
// SetFreeSubdomain assigns the lower-cased subdomain to the tenant.
//
// The boolean result is true when the subdomain is already present in the
// tenants table, in which case nothing is changed. An error is returned when
// the subdomain does not match subdomainRegex or when a query fails.
func (model *DBModel) SetFreeSubdomain(tenantId int, subdomain string) (bool, error) {
	subdomain = strings.ToLower(subdomain)
	if !model.subdomainRegex.MatchString(subdomain) {
		return false, errors.Errorf("SetFreeSubdomain(): subdomain '%s' is invalid", subdomain)
	}

	// QueryRow releases its connection once Scan returns, so the "already
	// taken" path no longer leaves an open result set behind.
	var exists int
	err := model.DB.QueryRow(`SELECT 1 FROM tenants WHERE subdomain = $1`, subdomain).Scan(&exists)
	if err == nil {
		return true, nil
	}
	if err != sql.ErrNoRows {
		return false, errors.Wrap(err, "SetFreeSubdomain(): ")
	}

	_, err = model.DB.Exec("UPDATE tenants SET subdomain = $1 WHERE id = $2", subdomain, tenantId)
	if err != nil {
		return false, errors.Wrap(err, "SetFreeSubdomain(): ")
	}
	return false, nil
}
// SetReservedPorts writes the reserved port range and port bucket onto the
// tenant's row.
func (model *DBModel) SetReservedPorts(tenantId, portStart, portEnd, portBucket int) error {
	const statement = "UPDATE tenants SET port_start = $1, port_end = $2, port_bucket = $3 WHERE id = $4"
	if _, updateErr := model.DB.Exec(statement, portStart, portEnd, portBucket, tenantId); updateErr != nil {
		return errors.Wrap(updateErr, "SetReservedPorts(): ")
	}
	return nil
}
// CreateAPIToken inserts a named, hashed API token for the tenant.
func (model *DBModel) CreateAPIToken(tenantId int, keyName, hashedAPIToken string) error {
	const statement = "INSERT INTO api_tokens (tenant_id, key_name, hashed_token) VALUES ($1, $2, $3)"
	if _, insertErr := model.DB.Exec(statement, tenantId, keyName, hashedAPIToken); insertErr != nil {
		return errors.Wrap(insertErr, "CreateAPIToken(): ")
	}
	return nil
}
// SetAPITokenActive sets the active flag of the tenant's API token with the
// given name. The token name lives in the key_name column, matching the
// INSERT in CreateAPIToken and the SELECT in GetTenant.
func (model *DBModel) SetAPITokenActive(tenantId int, keyName string, active bool) error {
	_, err := model.DB.Exec(
		"UPDATE api_tokens SET active = $1 WHERE tenant_id = $2 AND key_name = $3",
		active, tenantId, keyName,
	)
	if err != nil {
		return errors.Wrap(err, "SetAPITokenActive(): ")
	}
	return nil
}
// DeleteAPIToken removes the tenant's API token with the given name. The
// token name lives in the key_name column, matching the INSERT in
// CreateAPIToken and the SELECT in GetTenant.
func (model *DBModel) DeleteAPIToken(tenantId int, keyName string) error {
	_, err := model.DB.Exec(
		"DELETE FROM api_tokens WHERE tenant_id = $1 AND key_name = $2",
		tenantId, keyName,
	)
	if err != nil {
		return errors.Wrap(err, "DeleteAPIToken(): ")
	}
	return nil
}
// GetNextReservedPorts hands out the next range of PortRangeSize ports and
// advances the counter stored in reserved_ports_counter. It returns the first
// port, the last port (inclusive) and the bucket the range belongs to.
//
// When the next range would reach MaxListeningPort, the bucket is incremented
// and the range starts over at MinListeningPort. On error it returns
// (0, -1, -1, err).
func (model *DBModel) GetNextReservedPorts() (int, int, int, error) {
	var firstPort, bucket int
	counterRow := model.DB.QueryRow("SELECT port, bucket FROM reserved_ports_counter")
	if readErr := counterRow.Scan(&firstPort, &bucket); readErr != nil {
		return 0, -1, -1, errors.Wrap(readErr, "GetNextReservedPorts(): ")
	}

	if firstPort+model.PortRangeSize >= model.MaxListeningPort {
		bucket++
		firstPort = model.MinListeningPort
	}

	// The counter is moved to the first port after the range being returned.
	nextCounter := firstPort + model.PortRangeSize
	if _, writeErr := model.DB.Exec(
		"UPDATE reserved_ports_counter SET port = $1, bucket = $2", nextCounter, bucket,
	); writeErr != nil {
		return 0, -1, -1, errors.Wrap(writeErr, "GetNextReservedPorts(): ")
	}
	return firstPort, nextCounter - 1, bucket, nil
}
// GetVPSInstances returns every row of vps_instances that is not marked as
// deleted, keyed by the instance's GetId() value.
//
// The result set is always closed before returning, and an error that ends
// iteration early is reported instead of returning a truncated map.
func (model *DBModel) GetVPSInstances() (map[string]*VPSInstance, error) {
	rows, err := model.DB.Query(`
SELECT service_provider, provider_instance_id, tenant_id, ipv4, ipv6, bytes_monthly, created, deprecated, deleted
FROM vps_instances WHERE deleted = FALSE
`,
	)
	if err != nil {
		return nil, errors.Wrap(err, "GetVPSInstances(): ")
	}
	defer rows.Close()

	toReturn := map[string]*VPSInstance{}
	for rows.Next() {
		instance, err := model.rowToVPSInstance(rows)
		if err != nil {
			return nil, errors.Wrap(err, "GetVPSInstances(): ")
		}
		toReturn[instance.GetId()] = instance
	}
	if err := rows.Err(); err != nil {
		return nil, errors.Wrap(err, "GetVPSInstances(): ")
	}
	return toReturn, nil
}
// CreateVPSInstance inserts a row into vps_instances for toCreate. When
// toCreate.TenantId is non-zero, the row's tenant_id is set in a second
// statement; otherwise the column is left at whatever the insert gave it.
// The column is named tenant_id, matching the SELECT in GetVPSInstances.
func (model *DBModel) CreateVPSInstance(toCreate *VPSInstance) error {
	_, err := model.DB.Exec(`
INSERT INTO vps_instances (
service_provider, provider_instance_id, created,
ipv4, ipv6, bytes_monthly
)
VALUES($1, $2, $3,
$4, $5, $6)
`, toCreate.ServiceProvider, toCreate.ProviderInstanceId, toCreate.Created,
		toCreate.IPV4, toCreate.IPV6, toCreate.BytesMonthly,
	)
	if err != nil {
		return errors.Wrap(err, "CreateVPSInstance(): ")
	}

	if toCreate.TenantId != 0 {
		_, err := model.DB.Exec(
			`UPDATE vps_instances SET tenant_id = $1 WHERE service_provider = $2 AND provider_instance_id = $3`,
			toCreate.TenantId, toCreate.ServiceProvider, toCreate.ProviderInstanceId,
		)
		if err != nil {
			return errors.Wrap(err, "CreateVPSInstance(): ")
		}
	}
	return nil
}
// DeleteVPSInstance soft-deletes a VPS instance by setting its deleted flag.
// It returns an error when the update fails or when no row matched the given
// provider and instance id.
func (model *DBModel) DeleteVPSInstance(provider, providerInstanceId string) error {
	const errPrefix = "DeleteVPSInstance(): "

	outcome, execErr := model.DB.Exec(
		`UPDATE vps_instances SET deleted = TRUE WHERE service_provider = $1 AND provider_instance_id = $2`,
		provider, providerInstanceId,
	)
	if execErr != nil {
		return errors.Wrap(execErr, errPrefix)
	}

	switch affected, countErr := outcome.RowsAffected(); {
	case countErr != nil:
		return errors.Wrap(countErr, errPrefix)
	case affected == 0:
		return errors.Errorf("DeleteVPSInstance(): '%s-%s' was not found", provider, providerInstanceId)
	}
	return nil
}
// GetTenants loads every tenant together with its tunnel settings, keyed by
// tenant id.
//
// A tenant's AuthorizedDomains contains the external domains whose
// last_verified timestamp is newer than DomainVerificationPollingInterval plus
// one minute ago, followed by "<subdomain>.<freeSubdomainDomain>" when the
// tenant has a subdomain. The list is never nil.
//
// Both result sets are closed before returning, and an error that ends
// iteration early is reported instead of returning partial data.
func (model *DBModel) GetTenants() (map[int]*TenantInfo, error) {
	domainRows, err := model.DB.Query(`SELECT tenant_id, domain_name, last_verified FROM external_domains`)
	if err != nil {
		return nil, errors.Wrap(err, "GetTenants(): ")
	}
	defer domainRows.Close()

	verificationCutoff := time.Now().Add(-(DomainVerificationPollingInterval + time.Minute))
	authorizedDomains := map[int][]string{}
	for domainRows.Next() {
		var tenantId int
		var domainName string
		var lastVerified time.Time
		if err := domainRows.Scan(&tenantId, &domainName, &lastVerified); err != nil {
			return nil, errors.Wrap(err, "GetTenants(): ")
		}
		if lastVerified.After(verificationCutoff) {
			// append on a missing key starts from a nil slice, so no
			// explicit "first entry" case is needed.
			authorizedDomains[tenantId] = append(authorizedDomains[tenantId], domainName)
		}
	}
	if err := domainRows.Err(); err != nil {
		return nil, errors.Wrap(err, "GetTenants(): ")
	}

	tenantRows, err := model.DB.Query(`SELECT id, created, subdomain, service_limit_cents, port_start, port_end, port_bucket FROM tenants`)
	if err != nil {
		return nil, errors.Wrap(err, "GetTenants(): ")
	}
	defer tenantRows.Close()

	toReturn := map[int]*TenantInfo{}
	for tenantRows.Next() {
		var tenantId int
		var tenantCreated time.Time
		var subdomain *string // NULL when the tenant has no free subdomain
		var serviceLimitCents int
		var portStart int
		var portEnd int
		var portBucket int
		err := tenantRows.Scan(&tenantId, &tenantCreated, &subdomain, &serviceLimitCents, &portStart, &portEnd, &portBucket)
		if err != nil {
			return nil, errors.Wrap(err, "GetTenants(): ")
		}

		thisTenantDomains := authorizedDomains[tenantId]
		if thisTenantDomains == nil {
			thisTenantDomains = []string{}
		}
		subdomainString := ""
		if subdomain != nil {
			subdomainString = *subdomain
			thisTenantDomains = append(thisTenantDomains, fmt.Sprintf("%s.%s", subdomainString, freeSubdomainDomain))
		}

		toReturn[tenantId] = &TenantInfo{
			Id:                tenantId,
			Created:           tenantCreated,
			Subdomain:         subdomainString,
			ServiceLimitCents: serviceLimitCents,
			PortBucket:        portBucket,
			TunnelSettings: &TunnelSettings{
				PortStart:         portStart,
				PortEnd:           portEnd,
				AuthorizedDomains: thisTenantDomains,
			},
		}
	}
	if err := tenantRows.Err(); err != nil {
		return nil, errors.Wrap(err, "GetTenants(): ")
	}
	return toReturn, nil
}
// GetTenant loads a single tenant with its authorized domains, API tokens and
// account settings.
//
// AuthorizedDomains contains the tenant's external domains whose last_verified
// timestamp is newer than DomainVerificationPollingInterval plus one minute
// ago, followed by "<subdomain>.<freeSubdomainDomain>" when a subdomain is
// set. NULL subdomain and sms_alarm_number columns become "".
//
// An error is returned when any query fails, including when no tenant row
// exists for tenantId. Result sets are closed before returning and an error
// that ends iteration early is reported.
func (model *DBModel) GetTenant(tenantId int) (*TenantInfo, error) {
	domainRows, err := model.DB.Query(
		`SELECT domain_name, last_verified FROM external_domains WHERE tenant_id = $1`,
		tenantId,
	)
	if err != nil {
		return nil, errors.Wrapf(err, "GetTenant(%d): ", tenantId)
	}
	defer domainRows.Close()

	verificationCutoff := time.Now().Add(-(DomainVerificationPollingInterval + time.Minute))
	authorizedDomains := []string{}
	for domainRows.Next() {
		var domainName string
		var lastVerified time.Time
		if err := domainRows.Scan(&domainName, &lastVerified); err != nil {
			return nil, errors.Wrapf(err, "GetTenant(%d): ", tenantId)
		}
		if lastVerified.After(verificationCutoff) {
			authorizedDomains = append(authorizedDomains, domainName)
		}
	}
	if err := domainRows.Err(); err != nil {
		return nil, errors.Wrapf(err, "GetTenant(%d): ", tenantId)
	}

	tokenRows, err := model.DB.Query(
		`SELECT key_name, hashed_token, active, created, last_used FROM api_tokens WHERE tenant_id = $1`,
		tenantId,
	)
	if err != nil {
		return nil, errors.Wrapf(err, "GetTenant(%d): ", tenantId)
	}
	defer tokenRows.Close()

	apiTokens := []APIToken{}
	for tokenRows.Next() {
		var token APIToken
		err := tokenRows.Scan(&token.Name, &token.HashedToken, &token.Active, &token.Created, &token.LastUsed)
		if err != nil {
			return nil, errors.Wrapf(err, "GetTenant(%d): ", tenantId)
		}
		apiTokens = append(apiTokens, token)
	}
	if err := tokenRows.Err(); err != nil {
		return nil, errors.Wrapf(err, "GetTenant(%d): ", tenantId)
	}

	var created time.Time
	var subdomain *string // NULL when the tenant has no free subdomain
	var email string
	var smsAlarmNumber *string // NULL when no alarm number is configured
	var billingAlarmCents int
	var serviceLimitCents int
	var portStart int
	var portEnd int
	var portBucket int
	err = model.DB.QueryRow(
		`SELECT created, email, subdomain, sms_alarm_number, billing_alarm_cents, service_limit_cents,
port_start, port_end, port_bucket
FROM tenants WHERE id = $1`,
		tenantId,
	).Scan(
		&created, &email, &subdomain, &smsAlarmNumber, &billingAlarmCents, &serviceLimitCents,
		&portStart, &portEnd, &portBucket,
	)
	if err != nil {
		return nil, errors.Wrapf(err, "GetTenant(%d): ", tenantId)
	}

	subdomainString := ""
	if subdomain != nil {
		subdomainString = *subdomain
		authorizedDomains = append(authorizedDomains, fmt.Sprintf("%s.%s", subdomainString, freeSubdomainDomain))
	}
	smsString := ""
	if smsAlarmNumber != nil {
		smsString = *smsAlarmNumber
	}

	return &TenantInfo{
		Id:                tenantId,
		Email:             email,
		Created:           created,
		SMSAlarmNumber:    smsString,
		Subdomain:         subdomainString,
		BillingAlarmCents: billingAlarmCents,
		ServiceLimitCents: serviceLimitCents,
		PortBucket:        portBucket,
		TunnelSettings: &TunnelSettings{
			PortStart:         portStart,
			PortEnd:           portEnd,
			AuthorizedDomains: authorizedDomains,
		},
		APITokens: apiTokens,
	}, nil
}
// GetTenantVPSInstanceRows returns all tenant_vps_instance rows for the given
// billing year and month. Each row's shadow_config column is decoded from JSON
// into a TunnelSettings value; a row that fails to decode aborts the whole
// call with an error.
//
// The result set is closed before returning, and an error that ends iteration
// early is reported instead of returning a truncated slice.
func (model *DBModel) GetTenantVPSInstanceRows(billingYear, billingMonth int) ([]*TenantVPSInstance, error) {
	rows, err := model.DB.Query(`
SELECT
tenant_id,
service_provider,
provider_instance_id,
shadow_config,
bytes,
active,
deactivated_at
FROM tenant_vps_instance
WHERE billing_year = $1 AND billing_month = $2
`, billingYear, billingMonth,
	)
	if err != nil {
		return nil, errors.Wrap(err, "GetTenantVPSInstanceRows(): ")
	}
	defer rows.Close()

	toReturn := []*TenantVPSInstance{}
	for rows.Next() {
		var tenantId int
		var serviceProvider string
		var serviceProviderInstanceId string
		var shadowConfigString string
		var bytes int64
		var active bool
		var deactivatedAt *time.Time // NULL while the row has never been deactivated
		err := rows.Scan(&tenantId, &serviceProvider, &serviceProviderInstanceId, &shadowConfigString, &bytes, &active, &deactivatedAt)
		if err != nil {
			return nil, errors.Wrap(err, "GetTenantVPSInstanceRows(): ")
		}

		var shadowConfig TunnelSettings
		if err := json.Unmarshal([]byte(shadowConfigString), &shadowConfig); err != nil {
			return nil, errors.Wrap(err, "GetTenantVPSInstanceRows(): ")
		}

		toReturn = append(toReturn, &TenantVPSInstance{
			TenantId:           tenantId,
			ServiceProvider:    serviceProvider,
			ProviderInstanceId: serviceProviderInstanceId,
			ShadowConfig:       &shadowConfig,
			Bytes:              bytes,
			Active:             active,
			DeactivatedAt:      deactivatedAt,
		})
	}
	if err := rows.Err(); err != nil {
		return nil, errors.Wrap(err, "GetTenantVPSInstanceRows(): ")
	}
	return toReturn, nil
}
// RecordTenantsUsage inserts one tenant_metrics_bandwidth row per entry of
// usage (tenant id -> byte count). The inserts are handed to doInParallel;
// every failure is collected and reported together in a single error.
func (model *DBModel) RecordTenantsUsage(usage map[int]int64) error {
	inserts := make([]func() taskResult, 0, len(usage))
	for id, byteCount := range usage {
		// Copy the loop variables so each closure keeps its own pair.
		id, byteCount := id, byteCount
		inserts = append(inserts, func() taskResult {
			_, insertErr := model.DB.Exec(`INSERT INTO tenant_metrics_bandwidth (tenant_id, bytes) VALUES ($1, $2)`, id, byteCount)
			return taskResult{Name: strconv.Itoa(id), Err: insertErr}
		})
	}

	failures := []string{}
	for key, outcome := range doInParallel(false, inserts...) {
		if outcome.Err == nil {
			continue
		}
		failures = append(failures, fmt.Sprintf("tenant %s: %+v", key, outcome.Err))
	}

	if len(failures) == 0 {
		return nil
	}
	return errors.Errorf("RecordTenantsUsage(): \n%s\n", strings.Join(failures, "\n"))
}
// RecordVPSUsage adds the inbound plus outbound bytes of each tenant in usage
// to that tenant's tenant_vps_instance row for the given instance and billing
// period.
//
// Tenant keys in usage must parse as integers; the first key that does not
// aborts the call before any update runs. The updates are handed to
// doInParallel, and an update that does not touch exactly one row is treated
// as a failure. All failures are reported together in a single error.
func (model *DBModel) RecordVPSUsage(instance *VPSInstance, usage ThresholdMetrics, billingYear int, billingMonth int) error {
	totals := map[string]int64{}
	for tenantKey, inbound := range usage.InboundByTenant {
		totals[tenantKey] += inbound
	}
	for tenantKey, outbound := range usage.OutboundByTenant {
		totals[tenantKey] += outbound
	}

	// newUpdate builds the task that adds byteCount to one tenant's row.
	newUpdate := func(tenantId int, byteCount int64) func() taskResult {
		return func() taskResult {
			outcome, execErr := model.DB.Exec(
				`UPDATE tenant_vps_instance SET bytes = bytes + $1
WHERE service_provider = $2 AND provider_instance_id = $3 AND tenant_id = $4 AND billing_year = $5 AND billing_month = $6;`,
				byteCount, instance.ServiceProvider, instance.ProviderInstanceId, tenantId, billingYear, billingMonth,
			)
			if execErr != nil {
				return taskResult{Name: strconv.Itoa(tenantId), Err: execErr}
			}
			affected, countErr := outcome.RowsAffected()
			if countErr != nil {
				return taskResult{Name: strconv.Itoa(tenantId), Err: countErr}
			}
			if affected != 1 {
				return taskResult{
					Name: strconv.Itoa(tenantId),
					Err: errors.Errorf(
						"tenant_vps_instance row not found for vps '%s' tenant '%d'",
						instance.GetId(), tenantId,
					),
				}
			}
			return taskResult{Name: strconv.Itoa(tenantId)}
		}
	}

	updates := []func() taskResult{}
	for tenantKey, byteCount := range totals {
		tenantId, parseErr := strconv.Atoi(tenantKey)
		if parseErr != nil {
			return errors.Wrap(parseErr, "RecordVPSUsage(): ")
		}
		updates = append(updates, newUpdate(tenantId, byteCount))
	}

	failures := []string{}
	for key, outcome := range doInParallel(false, updates...) {
		if outcome.Err == nil {
			continue
		}
		failures = append(failures, fmt.Sprintf("tenant %s: %+v", key, outcome.Err))
	}
	if len(failures) == 0 {
		return nil
	}
	return errors.Errorf("RecordVPSUsage(): \n%s\n", strings.Join(failures, "\n"))
}
// GetTenantUsageTotal returns the sum of the bytes column over all of the
// tenant's tenant_vps_instance rows for the given billing year and month. It
// returns 0 when the tenant has no rows for that period.
//
// The result set is closed before returning, and an error that ends iteration
// early is reported instead of returning a partial sum.
func (model *DBModel) GetTenantUsageTotal(tenantId int, billingYear, billingMonth int) (int64, error) {
	rows, err := model.DB.Query(`
SELECT bytes FROM tenant_vps_instance WHERE tenant_id = $1 AND billing_year = $2 AND billing_month = $3
`, tenantId, billingYear, billingMonth,
	)
	if err != nil {
		return 0, errors.Wrap(err, "GetTenantUsageTotal(): ")
	}
	defer rows.Close()

	var monthlyBytes int64
	for rows.Next() {
		var bytes int64
		if err := rows.Scan(&bytes); err != nil {
			return 0, errors.Wrap(err, "GetTenantUsageTotal(): ")
		}
		monthlyBytes += bytes
	}
	if err := rows.Err(); err != nil {
		return 0, errors.Wrap(err, "GetTenantUsageTotal(): ")
	}
	return monthlyBytes, nil
}
// GetTenantUsageMetrics returns the tenant's tenant_metrics_bandwidth samples
// measured strictly between start and end, keyed by measurement time.
//
// The result set is closed before returning, and an error that ends iteration
// early is reported instead of returning a truncated map.
func (model *DBModel) GetTenantUsageMetrics(tenantId int, start, end time.Time) (map[time.Time]int64, error) {
	rows, err := model.DB.Query(`
SELECT measured, bytes FROM tenant_metrics_bandwidth WHERE tenant_id = $1 AND measured > $2 AND measured < $3
`, tenantId, start, end,
	)
	if err != nil {
		return nil, errors.Wrap(err, "GetTenantUsageMetrics(): ")
	}
	defer rows.Close()

	toReturn := map[time.Time]int64{}
	for rows.Next() {
		var measured time.Time
		var bytes int64
		if err := rows.Scan(&measured, &bytes); err != nil {
			return nil, errors.Wrap(err, "GetTenantUsageMetrics(): ")
		}
		toReturn[measured] = bytes
	}
	if err := rows.Err(); err != nil {
		return nil, errors.Wrap(err, "GetTenantUsageMetrics(): ")
	}
	return toReturn, nil
}
// SaveInstanceConfiguration persists the tunnel configuration that was applied
// to a VPS instance for the given billing period.
//
// For every tenant in config, the tenant's tenant_vps_instance row is created
// (with bytes = 0) or updated so that shadow_config holds the JSON-encoded
// settings and active is TRUE. Afterwards every other row of this instance in
// the same billing period is set to active = FALSE with deactivated_at = NOW().
// When config is empty, all rows of the instance for that period are
// deactivated.
//
// The statements run one after another; a failure part-way through leaves the
// earlier statements applied.
func (model *DBModel) SaveInstanceConfiguration(
	billingYear int,
	billingMonth int,
	instance *VPSInstance,
	config map[int]*TunnelSettings,
) error {
	// first we set shadow_config & active=true for all tenants mentioned in the config
	tenantIdsStrings := make([]string, 0, len(config))
	for tenantId, tunnelSettings := range config {
		tenantIdsStrings = append(tenantIdsStrings, strconv.Itoa(tenantId))

		shadowConfigBytes, err := json.Marshal(tunnelSettings)
		if err != nil {
			return errors.Wrapf(err, "cant serialize shadow config for tenant %d on %s", tenantId, instance.GetId())
		}
		shadowConfig := string(shadowConfigBytes)

		_, err = model.DB.Exec(`
INSERT INTO tenant_vps_instance (
billing_year, billing_month, tenant_id, service_provider, provider_instance_id,
shadow_config, bytes, active
)
VALUES($1, $2, $3, $4, $5,
$6, 0, $7)
ON CONFLICT ON CONSTRAINT pk_tenant_vps_instance
DO
UPDATE SET shadow_config = $6, active = $7;
`,
			billingYear, billingMonth, tenantId, instance.ServiceProvider, instance.ProviderInstanceId,
			shadowConfig, true,
		)
		if err != nil {
			return errors.Wrap(err, "SaveInstanceConfiguration(): ")
		}
	}

	// next, we disable all existing tenants for this instance which are not mentioned in the config
	deactivateQuery := `
UPDATE tenant_vps_instance SET active = FALSE, deactivated_at = NOW()
WHERE billing_year = $1 AND billing_month = $2 AND service_provider = $3 AND provider_instance_id = $4
`
	// "NOT IN ()" is not valid SQL, so the exclusion list is only added when
	// there is something to exclude. The ids come from strconv.Itoa, so
	// splicing them into the statement cannot inject SQL.
	if len(tenantIdsStrings) > 0 {
		deactivateQuery += fmt.Sprintf("AND tenant_id NOT IN (%s)\n", strings.Join(tenantIdsStrings, ", "))
	}
	_, err := model.DB.Exec(
		deactivateQuery,
		billingYear, billingMonth, instance.ServiceProvider, instance.ProviderInstanceId,
	)
	return errors.Wrap(err, "SaveInstanceConfiguration(): ")
}
// PutKeyPair stores the key and certificate bytes under (caName, name) in
// pki_key_pairs, overwriting the bytes of an existing entry with the same
// primary key.
func (model *DBModel) PutKeyPair(caName, name string, key, cert []byte) error {
	const upsert = `
INSERT INTO pki_key_pairs (ca_name, name, key_bytes, cert_bytes)
VALUES($1, $2, $3, $4)
ON CONFLICT ON CONSTRAINT pk_pki_key_pairs
DO
UPDATE SET key_bytes = $3, cert_bytes = $4;
`
	_, upsertErr := model.DB.Exec(upsert, caName, name, key, cert)
	return errors.Wrap(upsertErr, "PutKeyPair(): ")
}
// GetServerKeyPair returns the key and certificate bytes stored under
// (caName, name) in pki_key_pairs. When no such entry exists it returns
// (nil, nil, nil).
//
// QueryRow is used so the connection is released as soon as the single row
// has been read.
func (model *DBModel) GetServerKeyPair(caName, name string) ([]byte, []byte, error) {
	var key, cert []byte
	err := model.DB.QueryRow(`
SELECT key_bytes, cert_bytes FROM pki_key_pairs
WHERE ca_name = $1 AND name = $2
`,
		caName, name,
	).Scan(&key, &cert)
	if err == sql.ErrNoRows {
		return nil, nil, nil
	}
	if err != nil {
		return nil, nil, errors.Wrap(err, "GetServerKeyPair(): ")
	}
	return key, cert, nil
}
// rowToVPSInstance scans one vps_instances row into a VPSInstance. The columns
// must arrive in the order service_provider, provider_instance_id, tenant_id,
// ipv4, ipv6, bytes_monthly, created, deprecated, deleted. A NULL tenant_id
// becomes TenantId 0.
func (model *DBModel) rowToVPSInstance(row RowScanner) (*VPSInstance, error) {
	var (
		provider     string
		instanceId   string
		ownerId      *int // nil when tenant_id is NULL
		addressV4    string
		addressV6    string
		monthlyBytes int64
		createdAt    time.Time
		isDeprecated bool
		isDeleted    bool
	)
	if scanErr := row.Scan(
		&provider, &instanceId, &ownerId, &addressV4, &addressV6,
		&monthlyBytes, &createdAt, &isDeprecated, &isDeleted,
	); scanErr != nil {
		return nil, errors.Wrap(scanErr, "rowToVPSInstance(): ")
	}

	instance := &VPSInstance{
		ServiceProvider:    provider,
		ProviderInstanceId: instanceId,
		IPV4:               addressV4,
		IPV6:               addressV6,
		BytesMonthly:       monthlyBytes,
		Created:            createdAt,
		Deprecated:         isDeprecated,
		Deleted:            isDeleted,
	}
	if ownerId != nil {
		instance.TenantId = *ownerId
	}
	return instance, nil
}
// domainNamesToString returns the given domain names sorted and joined with
// commas, giving an order-independent string form of the list. The sort is
// done on a copy, so the caller's slice keeps its original order.
func domainNamesToString(slice []string) string {
	sorted := append([]string(nil), slice...)
	sort.Strings(sorted)
	return strings.Join(sorted, ",")
}