🌱🏠 a cloud service to enable your own server (owned by you and running on your computer) to be accessible on the internet in seconds, no credit card required https://greenhouse.server.garden/
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

1148 lines
33 KiB

9 months ago
9 months ago
8 months ago
8 months ago
9 months ago
9 months ago
9 months ago
9 months ago
9 months ago
9 months ago
9 months ago
9 months ago
9 months ago
9 months ago
9 months ago
9 months ago
9 months ago
9 months ago
9 months ago
9 months ago
9 months ago
9 months ago
9 months ago
9 months ago
9 months ago
8 months ago
8 months ago
  1. package main
  2. import (
  3. "database/sql"
  4. "encoding/json"
  5. "fmt"
  6. "io/ioutil"
  7. "log"
  8. "path/filepath"
  9. "regexp"
  10. "sort"
  11. "strconv"
  12. "strings"
  13. "time"
  14. errors "git.sequentialread.com/forest/pkg-errors"
  15. _ "github.com/lib/pq"
  16. )
// RowScanner is the minimal "something that can be scanned" abstraction: a
// single Scan method with the same shape as the Scan methods of *sql.Row and
// *sql.Rows.
type RowScanner interface {
	// Scan copies the columns of the current row into the given destinations.
	Scan(src ...interface{}) error
}
// DBModel bundles the database handle with the settings that the data-access
// methods in this file rely on. It is constructed by initDatabase.
type DBModel struct {
	// DB is the handle every query in this file is executed against.
	DB *sql.DB
	// subdomainRegex is matched against candidate names in SetFreeSubdomain.
	subdomainRegex *regexp.Regexp
	// PortRangeSize is the number of consecutive ports GetNextReservedPorts
	// hands out per call.
	PortRangeSize int
	// MinListeningPort is the port GetNextReservedPorts restarts from when it
	// moves on to the next bucket.
	MinListeningPort int
	// MaxListeningPort is the limit that makes GetNextReservedPorts move on to
	// the next bucket.
	MaxListeningPort int
}
// TunnelSettings is the per-tenant tunnel configuration: a reserved port range
// plus the domain names the tenant may use. It is also the value that gets
// JSON-encoded into / decoded from the shadow_config column.
type TunnelSettings struct {
	// PortStart is read from tenants.port_start.
	PortStart int
	// PortEnd is read from tenants.port_end.
	PortEnd int
	// AuthorizedDomains holds the tenant's recently verified external domains
	// and, when the tenant has one, its free subdomain.
	AuthorizedDomains []string
}
  32. func (settings *TunnelSettings) DeepEquals(other *TunnelSettings) bool {
  33. if domainNamesToString(settings.AuthorizedDomains) != domainNamesToString(other.AuthorizedDomains) {
  34. return false
  35. }
  36. if settings.PortStart != other.PortStart || settings.PortEnd != other.PortEnd {
  37. return false
  38. }
  39. return true
  40. }
// TenantInfo is the in-memory description of one tenant. GetTenant and
// GetTenants each populate a different subset of the fields; fields they do
// not set keep their zero value.
type TenantInfo struct {
	// Id is the value of tenants.id.
	Id int
	// Created is the value of tenants.created.
	Created time.Time
	// Email is the value of tenants.email (set by GetTenant only).
	Email string
	// Subdomain is tenants.subdomain, or "" when that column is NULL.
	Subdomain string
	// DedicatedVPSCount is not set by GetTenant or GetTenants.
	DedicatedVPSCount int
	// Bytes is not set by GetTenant or GetTenants.
	Bytes int64
	// SMSAlarmNumber is tenants.sms_alarm_number, or "" when that column is
	// NULL (set by GetTenant only).
	SMSAlarmNumber string
	// ServiceLimitCents is the value of tenants.service_limit_cents.
	ServiceLimitCents int
	// BillingAlarmCents is the value of tenants.billing_alarm_cents (set by
	// GetTenant only).
	BillingAlarmCents int
	// PortBucket is the value of tenants.port_bucket.
	PortBucket int
	// TunnelSettings carries the tenant's port range and authorized domains.
	TunnelSettings *TunnelSettings
	// Deactivated is not set by GetTenant or GetTenants.
	Deactivated bool
	// APITokens lists the tenant's api_tokens rows (set by GetTenant only).
	APITokens []APIToken
	// ExternalDomains lists the tenant's external_domains rows, verified or
	// not (set by GetTenant only).
	ExternalDomains []ExternalDomain
}
// TenantVPSInstance mirrors one row of the tenant_vps_instance table as read
// by GetTenantVPSInstanceRows.
type TenantVPSInstance struct {
	// TenantId is the value of the tenant_id column.
	TenantId int
	// ServiceProvider and ProviderInstanceId together identify the VPS; see
	// GetVPSInstanceId.
	ServiceProvider    string
	ProviderInstanceId string
	// ShadowConfig is the JSON-decoded shadow_config column.
	ShadowConfig *TunnelSettings
	// Bytes is the value of the bytes column.
	Bytes int64
	// Active is the value of the active column.
	Active bool
	// DeactivatedAt is the deactivated_at column; nil when the column is NULL.
	DeactivatedAt *time.Time
}
// APIToken mirrors one row of the api_tokens table as read by GetTenant.
type APIToken struct {
	// Name is the value of the key_name column.
	Name string
	// Active is the value of the active column; GetUserByAPIToken only
	// accepts tokens whose row has active = TRUE.
	Active bool
	// HashedToken is the value of the hashed_token column.
	HashedToken string
	// Created is the value of the created column.
	Created time.Time
	// LastUsed is the value of the last_used column.
	LastUsed time.Time
}
// ExternalDomain describes one external_domains row of a tenant as returned by
// GetTenant.
type ExternalDomain struct {
	// DomainName is the value of the domain_name column.
	DomainName string
	// IsValid is true when the row's last_verified timestamp is newer than the
	// verification cutoff computed in GetTenant.
	IsValid bool
}
// DomainVerificationPollingInterval, plus one minute of slack, is the maximum
// age an external domain's last_verified timestamp may have before GetTenants
// and GetTenant stop treating the domain as verified.
// NOTE(review): presumably also the period at which domains are re-verified
// elsewhere in the project — confirm against the caller.
const DomainVerificationPollingInterval = time.Hour
  78. func (i *TenantVPSInstance) GetVPSInstanceId() string {
  79. return fmt.Sprintf("%s-%s", i.ServiceProvider, i.ProviderInstanceId)
  80. }
  81. func initDatabase(config *Config) *DBModel {
  82. desiredSchemaVersion := 3
  83. db, err := sql.Open(config.DatabaseType, config.DatabaseConnectionString)
  84. if err != nil {
  85. log.Fatal(err)
  86. }
  87. if err := db.Ping(); err != nil {
  88. log.Fatalf("failed to open database connection: %+v", err)
  89. }
  90. var tableName string
  91. err = db.QueryRow(
  92. "SELECT table_name FROM information_schema.tables WHERE table_schema = $1 AND table_name = $2",
  93. config.DatabaseSchema, "schema_version",
  94. ).Scan(&tableName)
  95. if err == sql.ErrNoRows {
  96. _, err := db.Exec(`
  97. CREATE TABLE schema_version (
  98. version INT PRIMARY KEY NOT NULL
  99. );
  100. INSERT INTO schema_version(version) VALUES (1);
  101. `)
  102. if err != nil {
  103. log.Fatalf("failed to create schema_version table: %+v", err)
  104. }
  105. } else if err != nil {
  106. log.Fatalf("failed to check whether or not schema_version table exists: %+v", err)
  107. }
  108. readSchemaVersionFromDatabase := func() int {
  109. var currentSchemaVersion int
  110. err = db.QueryRow("SELECT version FROM schema_version").Scan(&currentSchemaVersion)
  111. if err != nil {
  112. log.Fatalf("failed to select currentSchemaVersion: %+v", err)
  113. }
  114. return currentSchemaVersion
  115. }
  116. currentSchemaVersion := readSchemaVersionFromDatabase()
  117. files, err := ioutil.ReadDir("schema_versions")
  118. if err != nil {
  119. log.Fatalf("failed to list schema_versions: %+v", err)
  120. }
  121. getMigrationScript := func(version int, direction string) (filename string, content string) {
  122. prefix := fmt.Sprintf("%02d_%s_", version, direction)
  123. for _, file := range files {
  124. if !file.IsDir() && strings.HasPrefix(file.Name(), prefix) && strings.HasSuffix(file.Name(), ".sql") {
  125. filename = filepath.Join("schema_versions", file.Name())
  126. contentsBytes, err := ioutil.ReadFile(filename)
  127. if err != nil {
  128. log.Fatalf("failed to read file '%s': %+v", filename, err)
  129. }
  130. content = fmt.Sprintf(`
  131. BEGIN TRANSACTION ISOLATION LEVEL SERIALIZABLE;
  132. %s
  133. UPDATE schema_version SET version = %d;
  134. COMMIT TRANSACTION;
  135. `, string(contentsBytes), version)
  136. break
  137. }
  138. }
  139. if content == "" {
  140. log.Fatalf("didn't find any files in schema_versions matching %s*.sql", prefix)
  141. }
  142. return filename, content
  143. }
  144. for currentSchemaVersion != desiredSchemaVersion {
  145. log.Printf(
  146. "currentSchemaVersion (%d) != desiredSchemaVersion (%d), running database schema migration(s)\n",
  147. currentSchemaVersion, desiredSchemaVersion,
  148. )
  149. var expectedSchemaVersion int
  150. var filename, content string
  151. if currentSchemaVersion < desiredSchemaVersion {
  152. expectedSchemaVersion = currentSchemaVersion + 1
  153. filename, content = getMigrationScript(expectedSchemaVersion, "up")
  154. _, err = db.Exec(content)
  155. if err != nil {
  156. log.Fatalf("failed to execute database migration script %s: %+v", filename, err)
  157. }
  158. } else {
  159. expectedSchemaVersion = currentSchemaVersion - 1
  160. filename, content = getMigrationScript(currentSchemaVersion, "down")
  161. _, err = db.Exec(content)
  162. if err != nil {
  163. log.Fatalf("failed to execute database migration script %s: %+v", filename, err)
  164. }
  165. }
  166. actualSchemaVersion := readSchemaVersionFromDatabase()
  167. if expectedSchemaVersion != actualSchemaVersion {
  168. log.Fatalf(
  169. "expecting database schema version (%d) to be %d after running database migration script %s",
  170. actualSchemaVersion, expectedSchemaVersion, filename,
  171. )
  172. }
  173. currentSchemaVersion = actualSchemaVersion
  174. }
  175. return &DBModel{
  176. DB: db,
  177. subdomainRegex: regexp.MustCompile("^[a-z0-9]([a-z0-9-_]*[a-z0-9]+)?$"),
  178. // TODO make these configurable?
  179. PortRangeSize: 20,
  180. MinListeningPort: 10000,
  181. MaxListeningPort: 30000,
  182. }
  183. }
  184. // ---------------- DBModel methods ----------------
  185. func (model *DBModel) Register(email, hashedPassword string) (int, error) {
  186. var existingAccount string
  187. err := model.DB.QueryRow("SELECT id FROM tenants WHERE email = $1", strings.ToLower(email)).Scan(&existingAccount)
  188. if err != sql.ErrNoRows {
  189. return 0, fmt.Errorf("email address '%s' is already associated with an account", strings.ToLower(email))
  190. }
  191. var inserted int
  192. err = model.DB.QueryRow(
  193. "INSERT INTO tenants (email, hashed_password) VALUES ($1, $2) RETURNING id",
  194. strings.ToLower(email), hashedPassword,
  195. ).Scan(&inserted)
  196. if err != nil {
  197. return 0, errors.Wrap(err, "Register(): could not insert row into tenants table")
  198. }
  199. return int(inserted), nil
  200. }
  201. func (model *DBModel) CreateEmailVerificationToken(token string, tenantId int, expires time.Time) error {
  202. //log.Printf("CreateEmailVerificationToken(): \n%s \n%d\n", expires.UTC().String(), expires.Unix())
  203. _, err := model.DB.Exec(
  204. "INSERT INTO email_verification_tokens (token, tenant_id, expires) VALUES ($1, $2, $3)",
  205. token, tenantId, expires.UTC(),
  206. )
  207. if err != nil {
  208. return errors.Wrap(err, "CreateEmailVerificationToken(): could not insert row into email_verification_tokens table")
  209. }
  210. return nil
  211. }
  212. func (model *DBModel) VerifyEmail(token string, tenantId int) error {
  213. var expires time.Time
  214. err := model.DB.QueryRow(
  215. "SELECT expires FROM email_verification_tokens WHERE token = $1 AND tenant_id = $2",
  216. token, tenantId,
  217. ).Scan(&expires)
  218. // log.Println("VerifyEmail():")
  219. // log.Println(expires)
  220. // log.Println(expires.Unix())
  221. if err != nil && err != sql.ErrNoRows {
  222. log.Printf("VerifyEmail(): query error %+v", err)
  223. }
  224. if err != nil || time.Now().UTC().After(expires) {
  225. return errors.New("email verification token was invalid or expired")
  226. } else {
  227. model.DB.Exec("DELETE FROM email_verification_tokens WHERE token = $1", token)
  228. _, err := model.DB.Exec("UPDATE tenants SET email_verified = TRUE WHERE id = $1", tenantId)
  229. if err != nil {
  230. return errors.New("internal error occurred during email verification")
  231. }
  232. }
  233. return nil
  234. }
  235. func (model *DBModel) GetLoginInfo(email string) (int, string, bool) {
  236. tenantId := 0
  237. var hashedPassword string
  238. var emailVerified bool
  239. err := model.DB.QueryRow(
  240. "SELECT id, hashed_password, email_verified FROM tenants WHERE email = $1",
  241. strings.ToLower(email),
  242. ).Scan(&tenantId, &hashedPassword, &emailVerified)
  243. if tenantId != 0 && err == nil {
  244. return tenantId, hashedPassword, emailVerified
  245. } else {
  246. return tenantId, "", false
  247. }
  248. }
  249. func (model *DBModel) GetSession(id string, cameFromLaxCookie bool) (*Session, error) {
  250. var loggedInTenantId int
  251. var emailVerified bool
  252. var email string
  253. var expires time.Time
  254. requireLaxCookie := ""
  255. if cameFromLaxCookie {
  256. requireLaxCookie = "and tenants.lax_cookie = TRUE"
  257. }
  258. err := model.DB.QueryRow(
  259. fmt.Sprintf(`
  260. SELECT session_cookies.tenant_id, tenants.email, tenants.email_verified, session_cookies.expires
  261. FROM session_cookies JOIN tenants on session_cookies.tenant_id = tenants.id
  262. WHERE session_cookies.id = $1 %s`,
  263. requireLaxCookie,
  264. ),
  265. id,
  266. ).Scan(&loggedInTenantId, &email, &emailVerified, &expires)
  267. // log.Println("GetSession():")
  268. // log.Println(expires.UTC())
  269. // log.Println(expires.Unix())
  270. if err == sql.ErrNoRows {
  271. return nil, nil
  272. }
  273. if err != nil {
  274. return nil, errors.Wrapf(err, "GetSession(id=%s, cameFromLaxCookie=%t): ", id, cameFromLaxCookie)
  275. }
  276. return &Session{
  277. TenantId: loggedInTenantId,
  278. Email: email,
  279. EmailVerified: emailVerified,
  280. Expires: expires.UTC(),
  281. }, nil
  282. }
  283. func (model *DBModel) SetSession(id string, session *Session) error {
  284. _, err := model.DB.Exec("UPDATE tenants SET lax_cookie = $1 WHERE id = $2", session.LaxCookie, session.TenantId)
  285. if err != nil {
  286. return errors.Wrap(err, "SetSession(): ")
  287. }
  288. _, err = model.DB.Exec("DELETE FROM session_cookies WHERE tenant_id = $1", session.TenantId)
  289. if err != nil {
  290. return errors.Wrap(err, "SetSession(): ")
  291. }
  292. _, err = model.DB.Exec("INSERT INTO session_cookies (id, tenant_id, expires) VALUES ($1, $2, $3)", id, session.TenantId, session.Expires.UTC())
  293. return errors.Wrap(err, "SetSession(): ")
  294. }
  295. func (model *DBModel) LogoutTenant(tenantId int) error {
  296. _, err := model.DB.Exec("DELETE FROM session_cookies WHERE tenant_id = $1", tenantId)
  297. return err
  298. }
  299. func (model *DBModel) GetUserByAPIToken(hashedApiToken string) (*Session, error) {
  300. var loggedInTenantId int
  301. var email string
  302. var emailVerified bool
  303. if len(hashedApiToken) < 8 {
  304. return nil, errors.New("The given hashedApiToken token was too short")
  305. }
  306. err := model.DB.QueryRow(
  307. `SELECT tenant_id FROM api_tokens WHERE hashed_token = $1 AND active = TRUE`,
  308. hashedApiToken,
  309. ).Scan(&loggedInTenantId)
  310. if err == sql.ErrNoRows {
  311. return nil, nil
  312. }
  313. if err != nil {
  314. return nil, errors.Wrapf(err, "GetUserByAPIToken(hashedApiToken=%s): ", hashedApiToken)
  315. }
  316. err = model.DB.QueryRow(
  317. `SELECT email, email_verified FROM tenants WHERE id = $1`, loggedInTenantId,
  318. ).Scan(&email, &emailVerified)
  319. if err != nil {
  320. return nil, errors.Wrapf(err, "GetUserByAPIToken(hashedApiToken=%s): ", hashedApiToken)
  321. }
  322. return &Session{
  323. TenantId: loggedInTenantId,
  324. Email: email,
  325. EmailVerified: emailVerified,
  326. }, nil
  327. }
  328. func (model *DBModel) SetFreeSubdomain(tenantId int, subdomain string) (bool, error) {
  329. subdomain = strings.ToLower(subdomain)
  330. if !model.subdomainRegex.MatchString(subdomain) {
  331. return false, errors.Errorf("SetFreeSubdomain(): subdomain '%s' is invalid", subdomain)
  332. }
  333. rows, err := model.DB.Query(`SELECT 1 FROM tenants WHERE subdomain = $1`, subdomain)
  334. if err != nil {
  335. return false, errors.Wrap(err, "SetFreeSubdomain(): ")
  336. }
  337. defer rows.Close()
  338. if rows.Next() {
  339. return true, nil
  340. }
  341. _, err = model.DB.Exec("UPDATE tenants SET subdomain = $1 WHERE id = $2", subdomain, tenantId)
  342. if err != nil {
  343. return false, errors.Wrap(err, "SetFreeSubdomain(): ")
  344. }
  345. return false, nil
  346. }
  347. func (model *DBModel) SetReservedPorts(tenantId, portStart, portEnd, portBucket int) error {
  348. _, err := model.DB.Exec(
  349. "UPDATE tenants SET port_start = $1, port_end = $2, port_bucket = $3 WHERE id = $4",
  350. portStart, portEnd, portBucket, tenantId,
  351. )
  352. if err != nil {
  353. return errors.Wrap(err, "SetReservedPorts(): ")
  354. }
  355. return nil
  356. }
  357. func (model *DBModel) CreateAPIToken(tenantId int, keyName, hashedAPIToken string) error {
  358. _, err := model.DB.Exec(
  359. "INSERT INTO api_tokens (tenant_id, key_name, hashed_token) VALUES ($1, $2, $3)",
  360. tenantId, keyName, hashedAPIToken,
  361. )
  362. if err != nil {
  363. return errors.Wrap(err, "CreateAPIToken(): ")
  364. }
  365. return nil
  366. }
  367. func (model *DBModel) SetAPITokenActive(tenantId int, keyName string, active bool) error {
  368. _, err := model.DB.Exec(
  369. "UPDATE api_tokens SET active = $1 WHERE tenant_id = $2 AND keyName = $3",
  370. active, tenantId, keyName,
  371. )
  372. if err != nil {
  373. return errors.Wrap(err, "SetAPITokenActive(): ")
  374. }
  375. return nil
  376. }
  377. func (model *DBModel) DeleteAPIToken(tenantId int, keyName string) error {
  378. _, err := model.DB.Exec(
  379. "DELETE FROM api_tokens WHERE tenant_id = $1 AND keyName = $2",
  380. tenantId, keyName,
  381. )
  382. if err != nil {
  383. return errors.Wrap(err, "DeleteAPIToken(): ")
  384. }
  385. return nil
  386. }
  387. func (model *DBModel) GetNextReservedPorts() (int, int, int, error) {
  388. port := 0
  389. bucket := 0
  390. err := model.DB.QueryRow("SELECT port, bucket FROM reserved_ports_counter").Scan(&port, &bucket)
  391. if err != nil {
  392. return 0, -1, -1, errors.Wrap(err, "GetNextReservedPorts(): ")
  393. }
  394. if port+model.PortRangeSize >= model.MaxListeningPort {
  395. bucket++
  396. port = model.MinListeningPort
  397. }
  398. _, err = model.DB.Exec("UPDATE reserved_ports_counter SET port = $1, bucket = $2", port+model.PortRangeSize, bucket)
  399. if err != nil {
  400. return 0, -1, -1, errors.Wrap(err, "GetNextReservedPorts(): ")
  401. }
  402. return port, port + (model.PortRangeSize - 1), bucket, nil
  403. }
  404. func (model *DBModel) GetVPSInstances() (map[string]*VPSInstance, error) {
  405. rows, err := model.DB.Query(`
  406. SELECT service_provider, provider_instance_id, tenant_id, ipv4, ipv6, bytes_monthly, created, deprecated, deleted
  407. FROM vps_instances WHERE deleted = FALSE
  408. `,
  409. )
  410. if err != nil {
  411. return nil, errors.Wrap(err, "GetVPSInstances(): ")
  412. }
  413. defer rows.Close()
  414. toReturn := map[string]*VPSInstance{}
  415. for rows.Next() {
  416. instance, err := model.rowToVPSInstance(rows)
  417. if err != nil {
  418. return nil, errors.Wrap(err, "GetVPSInstances(): ")
  419. }
  420. toReturn[instance.GetId()] = instance
  421. }
  422. return toReturn, nil
  423. }
  424. func (model *DBModel) CreateVPSInstance(toCreate *VPSInstance) error {
  425. _, err := model.DB.Exec(`
  426. INSERT INTO vps_instances (
  427. service_provider, provider_instance_id, created,
  428. ipv4, ipv6, bytes_monthly
  429. )
  430. VALUES($1, $2, $3,
  431. $4, $5, $6)
  432. `, toCreate.ServiceProvider, toCreate.ProviderInstanceId, toCreate.Created,
  433. toCreate.IPV4, toCreate.IPV6, toCreate.BytesMonthly,
  434. )
  435. if err != nil {
  436. return errors.Wrap(err, "CreateVPSInstance(): ")
  437. }
  438. if toCreate.TenantId != 0 {
  439. _, err := model.DB.Exec(
  440. `UPDATE vps_instances SET tenantId = $1 WHERE service_provider = $2 AND provider_instance_id = $3`,
  441. toCreate.TenantId, toCreate.ServiceProvider, toCreate.ProviderInstanceId,
  442. )
  443. if err != nil {
  444. return errors.Wrap(err, "CreateVPSInstance(): ")
  445. }
  446. }
  447. return nil
  448. }
  449. func (model *DBModel) DeleteVPSInstance(provider, providerInstanceId string) error {
  450. result, err := model.DB.Exec(
  451. `UPDATE vps_instances SET deleted = TRUE WHERE service_provider = $1 AND provider_instance_id = $2`,
  452. provider, providerInstanceId,
  453. )
  454. if err != nil {
  455. return errors.Wrap(err, "DeleteVPSInstance(): ")
  456. }
  457. rowsAffected, err := result.RowsAffected()
  458. if err != nil {
  459. return errors.Wrap(err, "DeleteVPSInstance(): ")
  460. }
  461. if rowsAffected == 0 {
  462. return errors.Errorf("DeleteVPSInstance(): '%s-%s' was not found", provider, providerInstanceId)
  463. }
  464. return nil
  465. }
  466. func (model *DBModel) GetTenants() (map[int]*TenantInfo, error) {
  467. rows, err := model.DB.Query(`SELECT tenant_id, domain_name, last_verified FROM external_domains`)
  468. if err != nil {
  469. return nil, errors.Wrap(err, "GetTenants(): ")
  470. }
  471. defer rows.Close()
  472. verificationCutoff := time.Now().UTC().Add(-(DomainVerificationPollingInterval + time.Minute))
  473. authorizedDomains := map[int][]string{}
  474. for rows.Next() {
  475. var tenantId int
  476. var domainName string
  477. var lastVerified time.Time
  478. err := rows.Scan(&tenantId, &domainName, &lastVerified)
  479. if err != nil {
  480. return nil, errors.Wrap(err, "GetTenants(): ")
  481. }
  482. if lastVerified.After(verificationCutoff) {
  483. if _, has := authorizedDomains[tenantId]; !has {
  484. authorizedDomains[tenantId] = []string{domainName}
  485. } else {
  486. authorizedDomains[tenantId] = append(authorizedDomains[tenantId], domainName)
  487. }
  488. }
  489. }
  490. rows, err = model.DB.Query(`SELECT id, created, subdomain, service_limit_cents, port_start, port_end, port_bucket FROM tenants`)
  491. if err != nil {
  492. return nil, errors.Wrap(err, "GetTenants(): ")
  493. }
  494. defer rows.Close()
  495. toReturn := map[int]*TenantInfo{}
  496. for rows.Next() {
  497. var tenantId int
  498. var tenantCreated time.Time
  499. var subdomain *string
  500. var serviceLimitCents int
  501. var portStart int
  502. var portEnd int
  503. var portBucket int
  504. err := rows.Scan(&tenantId, &tenantCreated, &subdomain, &serviceLimitCents, &portStart, &portEnd, &portBucket)
  505. if err != nil {
  506. return nil, errors.Wrap(err, "GetTenants(): ")
  507. }
  508. thisTenantDomains := authorizedDomains[tenantId]
  509. if thisTenantDomains == nil {
  510. thisTenantDomains = []string{}
  511. }
  512. // thisTenantPorts := reservedPorts[tenantId]
  513. // if thisTenantPorts == nil {
  514. // thisTenantPorts = []PortRange{}
  515. // }
  516. subdomainString := ""
  517. if subdomain != nil {
  518. subdomainString = *subdomain
  519. thisTenantDomains = append(thisTenantDomains, fmt.Sprintf("%s.%s", subdomainString, freeSubdomainDomain))
  520. }
  521. toReturn[tenantId] = &TenantInfo{
  522. Id: tenantId,
  523. Created: tenantCreated,
  524. Subdomain: subdomainString,
  525. ServiceLimitCents: serviceLimitCents,
  526. PortBucket: portBucket,
  527. TunnelSettings: &TunnelSettings{
  528. PortStart: portStart,
  529. PortEnd: portEnd,
  530. AuthorizedDomains: thisTenantDomains,
  531. },
  532. }
  533. }
  534. return toReturn, nil
  535. }
  536. func (model *DBModel) GetTenant(tenantId int) (*TenantInfo, error) {
  537. rows, err := model.DB.Query(
  538. `SELECT domain_name, last_verified FROM external_domains WHERE tenant_id = $1`,
  539. tenantId,
  540. )
  541. if err != nil {
  542. return nil, errors.Wrapf(err, "GetTenant(%d): ", tenantId)
  543. }
  544. defer rows.Close()
  545. verificationCutoff := time.Now().UTC().Add(-(DomainVerificationPollingInterval + time.Minute))
  546. authorizedDomains := []string{}
  547. externalDomains := []ExternalDomain{}
  548. for rows.Next() {
  549. var domainName string
  550. var lastVerified time.Time
  551. err := rows.Scan(&domainName, &lastVerified)
  552. if err != nil {
  553. return nil, errors.Wrapf(err, "GetTenant(%d): ", tenantId)
  554. }
  555. verified := lastVerified.After(verificationCutoff)
  556. if verified {
  557. authorizedDomains = append(authorizedDomains, domainName)
  558. }
  559. externalDomains = append(externalDomains, ExternalDomain{DomainName: domainName, IsValid: verified})
  560. }
  561. tokensRows, err := model.DB.Query(
  562. `SELECT key_name, hashed_token, active, created, last_used FROM api_tokens WHERE tenant_id = $1`,
  563. tenantId,
  564. )
  565. if err != nil {
  566. return nil, errors.Wrapf(err, "GetTenant(%d): ", tenantId)
  567. }
  568. defer tokensRows.Close()
  569. apiTokens := []APIToken{}
  570. for tokensRows.Next() {
  571. var keyName string
  572. var hashedToken string
  573. var active bool
  574. var created time.Time
  575. var lastUsed time.Time
  576. err := tokensRows.Scan(&keyName, &hashedToken, &active, &created, &lastUsed)
  577. if err != nil {
  578. return nil, errors.Wrapf(err, "GetTenant(%d): ", tenantId)
  579. }
  580. apiTokens = append(apiTokens, APIToken{
  581. Name: keyName,
  582. HashedToken: hashedToken,
  583. Active: active,
  584. Created: created,
  585. LastUsed: lastUsed,
  586. })
  587. }
  588. var created time.Time
  589. var subdomain *string
  590. var email string
  591. var smsAlarmNumber *string
  592. var billingAlarmCents int
  593. var serviceLimitCents int
  594. var portStart int
  595. var portEnd int
  596. var portBucket int
  597. err = model.DB.QueryRow(
  598. `SELECT created, email, subdomain, sms_alarm_number, billing_alarm_cents, service_limit_cents,
  599. port_start, port_end, port_bucket
  600. FROM tenants WHERE id = $1`,
  601. tenantId,
  602. ).Scan(
  603. &created, &email, &subdomain, &smsAlarmNumber, &billingAlarmCents, &serviceLimitCents,
  604. &portStart, &portEnd, &portBucket,
  605. )
  606. if err != nil {
  607. return nil, errors.Wrapf(err, "GetTenant(%d): ", tenantId)
  608. }
  609. subdomainString := ""
  610. if subdomain != nil {
  611. subdomainString = *subdomain
  612. authorizedDomains = append(authorizedDomains, fmt.Sprintf("%s.%s", subdomainString, freeSubdomainDomain))
  613. }
  614. smsString := ""
  615. if smsAlarmNumber != nil {
  616. smsString = *smsAlarmNumber
  617. }
  618. return &TenantInfo{
  619. Id: tenantId,
  620. Email: email,
  621. Created: created,
  622. SMSAlarmNumber: smsString,
  623. Subdomain: subdomainString,
  624. BillingAlarmCents: billingAlarmCents,
  625. ServiceLimitCents: serviceLimitCents,
  626. PortBucket: portBucket,
  627. TunnelSettings: &TunnelSettings{
  628. PortStart: portStart,
  629. PortEnd: portEnd,
  630. AuthorizedDomains: authorizedDomains,
  631. },
  632. APITokens: apiTokens,
  633. ExternalDomains: externalDomains,
  634. }, nil
  635. }
  636. func (model *DBModel) AddExternalDomain(tenantId int, externalDomain string) error {
  637. _, err := model.DB.Exec("INSERT INTO external_domains (tenant_id, domain_name) VALUES ($1, $2)", tenantId, externalDomain)
  638. if err != nil {
  639. return errors.Wrap(err, "AddExternalDomain(): ")
  640. }
  641. return nil
  642. }
  643. func (model *DBModel) GetExternalDomains() ([][]string, error) {
  644. rows, err := model.DB.Query(`SELECT id, subdomain FROM tenants`)
  645. if err != nil {
  646. return nil, errors.Wrap(err, "GetExternalDomains(): ")
  647. }
  648. defer rows.Close()
  649. personalDomainsByTenant := map[int]string{}
  650. for rows.Next() {
  651. var tenantId int
  652. var subdomain *string
  653. err := rows.Scan(&tenantId, &subdomain)
  654. if err != nil {
  655. return nil, errors.Wrap(err, "GetExternalDomains(): ")
  656. }
  657. if subdomain != nil {
  658. personalDomainsByTenant[tenantId] = fmt.Sprintf("%s.%s", *subdomain, freeSubdomainDomain)
  659. }
  660. }
  661. externalDomainsRows, err := model.DB.Query(`SELECT tenant_id, domain_name FROM external_domains`)
  662. if err != nil {
  663. return nil, errors.Wrap(err, "GetTenants(): ")
  664. }
  665. defer externalDomainsRows.Close()
  666. toReturn := [][]string{}
  667. for externalDomainsRows.Next() {
  668. var tenantId int
  669. var externalDomain string
  670. err := externalDomainsRows.Scan(&tenantId, &externalDomain)
  671. if err != nil {
  672. return nil, errors.Wrap(err, "GetExternalDomains(): ")
  673. }
  674. personalDomain, hasPersonalDomain := personalDomainsByTenant[tenantId]
  675. if hasPersonalDomain {
  676. toReturn = append(toReturn, []string{externalDomain, personalDomain})
  677. }
  678. }
  679. return toReturn, nil
  680. }
  681. func (model *DBModel) MarkExternalDomainAsVerified(externalDomain string) error {
  682. result, err := model.DB.Exec("UPDATE external_domains SET last_verified = NOW() WHERE domain_name = $1", externalDomain)
  683. if err != nil {
  684. return errors.Wrap(err, "MarkExternalDomainAsVerified(): ")
  685. }
  686. affected, err := result.RowsAffected()
  687. if err != nil {
  688. return errors.Wrap(err, "MarkExternalDomainAsVerified(): ")
  689. }
  690. if affected != 1 {
  691. return errors.Errorf("zero rows were affected by MarkExternalDomainAsVerified('%s')", externalDomain)
  692. }
  693. return nil
  694. }
  695. func (model *DBModel) GetTenantVPSInstanceRows(billingYear, billingMonth int) ([]*TenantVPSInstance, error) {
  696. // tenantCondition := ""
  697. // if tenantId > 0 {
  698. // tenantCondition = "AND tenant_vps_instance.tenant_id = $3"
  699. // }
  700. rows, err := model.DB.Query(`
  701. SELECT
  702. tenant_id,
  703. service_provider,
  704. provider_instance_id,
  705. shadow_config,
  706. bytes,
  707. active,
  708. deactivated_at
  709. FROM tenant_vps_instance
  710. WHERE billing_year = $1 AND billing_month = $2
  711. `, billingYear, billingMonth,
  712. )
  713. if err != nil {
  714. return nil, errors.Wrap(err, "GetTenantVPSInstanceRows(): ")
  715. }
  716. defer rows.Close()
  717. toReturn := []*TenantVPSInstance{}
  718. for rows.Next() {
  719. var tenantId int
  720. var serviceProvider string
  721. var serviceProviderInstanceId string
  722. var shadowConfigString string
  723. var bytes int64
  724. var active bool
  725. var deactivatedAt *time.Time
  726. err := rows.Scan(&tenantId, &serviceProvider, &serviceProviderInstanceId, &shadowConfigString, &bytes, &active, &deactivatedAt)
  727. if err != nil {
  728. return nil, errors.Wrap(err, "GetTenantVPSInstanceRows(): ")
  729. }
  730. var shadowConfig TunnelSettings
  731. err = json.Unmarshal([]byte(shadowConfigString), &shadowConfig)
  732. if err != nil {
  733. return nil, errors.Wrap(err, "GetTenantVPSInstanceRows(): ")
  734. }
  735. toReturn = append(toReturn, &TenantVPSInstance{
  736. TenantId: tenantId,
  737. ServiceProvider: serviceProvider,
  738. ProviderInstanceId: serviceProviderInstanceId,
  739. ShadowConfig: &shadowConfig,
  740. Bytes: bytes,
  741. Active: active,
  742. DeactivatedAt: deactivatedAt,
  743. })
  744. }
  745. return toReturn, nil
  746. }
  747. func (model *DBModel) RecordTenantsUsage(usage map[int]int64) error {
  748. actions := []func() taskResult{}
  749. for tenantId, bytez := range usage {
  750. tenantId := tenantId
  751. bytez := bytez
  752. actions = append(actions, func() taskResult {
  753. _, err := model.DB.Exec(`INSERT INTO tenant_metrics_bandwidth (tenant_id, bytes) VALUES ($1, $2)`, tenantId, bytez)
  754. return taskResult{Name: strconv.Itoa(tenantId), Err: err}
  755. })
  756. }
  757. results := doInParallel(false, actions...)
  758. errorStrings := []string{}
  759. for tenantId, result := range results {
  760. if result.Err != nil {
  761. errorStrings = append(errorStrings, fmt.Sprintf("tenant %s: %+v", tenantId, result.Err))
  762. }
  763. }
  764. if len(errorStrings) != 0 {
  765. return errors.Errorf("RecordTenantsUsage(): \n%s\n", strings.Join(errorStrings, "\n"))
  766. }
  767. return nil
  768. }
  769. func (model *DBModel) RecordVPSUsage(instance *VPSInstance, usage ThresholdMetrics, billingYear int, billingMonth int) error {
  770. bytesByTenant := map[string]int64{}
  771. for k, v := range usage.InboundByTenant {
  772. bytesByTenant[k] += v
  773. }
  774. for k, v := range usage.OutboundByTenant {
  775. bytesByTenant[k] += v
  776. }
  777. actions := []func() taskResult{}
  778. for tenantIdString, bytez := range bytesByTenant {
  779. tenantIdInt, err := strconv.Atoi(tenantIdString)
  780. if err != nil {
  781. return errors.Wrap(err, "RecordVPSUsage(): ")
  782. }
  783. bytez := bytez
  784. actions = append(actions, func() taskResult {
  785. result, err := model.DB.Exec(
  786. `UPDATE tenant_vps_instance SET bytes = bytes + $1
  787. WHERE service_provider = $2 AND provider_instance_id = $3 AND tenant_id = $4 AND billing_year = $5 AND billing_month = $6;`,
  788. bytez, instance.ServiceProvider, instance.ProviderInstanceId, tenantIdInt, billingYear, billingMonth,
  789. )
  790. if err != nil {
  791. return taskResult{Name: strconv.Itoa(tenantIdInt), Err: err}
  792. }
  793. rowsAffected, err := result.RowsAffected()
  794. if err != nil {
  795. return taskResult{Name: strconv.Itoa(tenantIdInt), Err: err}
  796. }
  797. if rowsAffected != 1 {
  798. return taskResult{
  799. Name: strconv.Itoa(tenantIdInt),
  800. Err: errors.Errorf(
  801. "tenant_vps_instance row not found for vps '%s' tenant '%d'",
  802. instance.GetId(), tenantIdInt,
  803. ),
  804. }
  805. }
  806. return taskResult{Name: strconv.Itoa(tenantIdInt)}
  807. })
  808. }
  809. results := doInParallel(false, actions...)
  810. errorStrings := []string{}
  811. for tenantId, result := range results {
  812. if result.Err != nil {
  813. errorStrings = append(errorStrings, fmt.Sprintf("tenant %s: %+v", tenantId, result.Err))
  814. }
  815. }
  816. if len(errorStrings) != 0 {
  817. return errors.Errorf("RecordVPSUsage(): \n%s\n", strings.Join(errorStrings, "\n"))
  818. }
  819. return nil
  820. }
  821. func (model *DBModel) GetTenantUsageTotal(tenantId int, billingYear, billingMonth int) (int64, error) {
  822. rows, err := model.DB.Query(`
  823. SELECT bytes FROM tenant_vps_instance WHERE tenant_id = $1 AND billing_year = $2 AND billing_month = $3
  824. `, tenantId, billingYear, billingMonth,
  825. )
  826. if err != nil {
  827. return 0, errors.Wrap(err, "GetTenantUsage(): ")
  828. }
  829. defer rows.Close()
  830. var monthlyBytes int64
  831. for rows.Next() {
  832. var bytes int64
  833. err := rows.Scan(&bytes)
  834. if err != nil {
  835. return 0, errors.Wrap(err, "GetTenantUsage(): ")
  836. }
  837. monthlyBytes += bytes
  838. }
  839. return monthlyBytes, nil
  840. }
  841. func (model *DBModel) GetTenantUsageMetrics(tenantId int, start, end time.Time) (map[time.Time]int64, error) {
  842. rows, err := model.DB.Query(`
  843. SELECT measured, bytes FROM tenant_metrics_bandwidth WHERE tenant_id = $1 AND measured > $2 AND measured < $3
  844. `, tenantId, start, end,
  845. )
  846. if err != nil {
  847. return nil, errors.Wrap(err, "GetTenantUsageMetrics(): ")
  848. }
  849. defer rows.Close()
  850. toReturn := map[time.Time]int64{}
  851. for rows.Next() {
  852. var measured time.Time
  853. var bytes int64
  854. err := rows.Scan(&measured, &bytes)
  855. if err != nil {
  856. return nil, errors.Wrap(err, "GetTenantUsageMetrics(): ")
  857. }
  858. toReturn[measured] = bytes
  859. }
  860. return toReturn, nil
  861. }
  862. func (model *DBModel) SaveInstanceConfiguration(
  863. billingYear int,
  864. billingMonth int,
  865. instance *VPSInstance,
  866. config map[int]*TunnelSettings,
  867. ) error {
  868. // first we set shadow_config & active=true for all tenants mentioned in the config
  869. tenantIds := []int{}
  870. for tenantId, tunnelSettings := range config {
  871. tenantIds = append(tenantIds, tenantId)
  872. shadowConfigBytes, err := json.Marshal(tunnelSettings)
  873. if err != nil {
  874. return errors.Wrapf(err, "cant serialize shadow config for tenant %s on %s", tenantId, instance.GetId())
  875. }
  876. shadowConfig := string(shadowConfigBytes)
  877. _, err = model.DB.Exec(`
  878. INSERT INTO tenant_vps_instance (
  879. billing_year, billing_month, tenant_id, service_provider, provider_instance_id,
  880. shadow_config, bytes, active
  881. )
  882. VALUES($1, $2, $3, $4, $5,
  883. $6, 0, $7)
  884. ON CONFLICT ON CONSTRAINT pk_tenant_vps_instance
  885. DO
  886. UPDATE SET shadow_config = $6, active = $7;
  887. `,
  888. billingYear, billingMonth, tenantId, instance.ServiceProvider, instance.ProviderInstanceId,
  889. shadowConfig, true,
  890. )
  891. if err != nil {
  892. return errors.Wrap(err, "SaveInstanceConfiguration(): ")
  893. }
  894. }
  895. // next, we disable all existing tenants for this instance which are not mentioned in the config
  896. tenantIdsStrings := make([]string, len(tenantIds))
  897. for i, id := range tenantIds {
  898. tenantIdsStrings[i] = strconv.Itoa(id)
  899. }
  900. _, err := model.DB.Exec(
  901. fmt.Sprintf(`
  902. UPDATE tenant_vps_instance SET active = FALSE, deactivated_at = NOW()
  903. WHERE billing_year = $1 AND billing_month = $2 AND service_provider = $3 AND provider_instance_id = $4
  904. AND tenant_id NOT IN (%s)
  905. `, strings.Join(tenantIdsStrings, ", ")),
  906. billingYear, billingMonth, instance.ServiceProvider, instance.ProviderInstanceId,
  907. )
  908. return errors.Wrap(err, "SaveInstanceConfiguration(): ")
  909. }
  910. func (model *DBModel) PollScheduledTask(name string, every time.Duration) (bool, error) {
  911. rows, err := model.DB.Query(`SELECT last_started, last_succeeded FROM scheduled_tasks WHERE name = $1`, name)
  912. if err != nil {
  913. return false, errors.Wrap(err, "PollScheduledTask(): ")
  914. }
  915. defer rows.Close()
  916. if rows.Next() {
  917. var lastStarted time.Time
  918. var lastSucceeded time.Time
  919. err := rows.Scan(&lastStarted, &lastSucceeded)
  920. //log.Printf("st %s, sc %s, since: %s\n", lastStarted, lastSucceeded, time.Since(lastSucceeded))
  921. if err != nil {
  922. return false, errors.Wrap(err, "PollScheduledTask(): ")
  923. }
  924. if time.Since(lastSucceeded) > every {
  925. _, err := model.DB.Exec("UPDATE scheduled_tasks SET last_started = $1 WHERE name = $2", time.Now().UTC(), name)
  926. if err != nil {
  927. return false, errors.Wrap(err, "PollScheduledTask(): ")
  928. }
  929. return true, nil
  930. }
  931. return false, nil
  932. } else {
  933. unixEpoch := time.Date(1970, 1, 1, 0, 0, 0, 1, time.Now().UTC().Location())
  934. _, err := model.DB.Exec("INSERT INTO scheduled_tasks (name, last_succeeded) VALUES ($1, $2)", name, unixEpoch)
  935. if err != nil {
  936. return false, errors.Wrap(err, "PollScheduledTask(): ")
  937. }
  938. return true, nil
  939. }
  940. }
  941. func (model *DBModel) ScheduledTaskCompleted(name string) error {
  942. _, err := model.DB.Exec("UPDATE scheduled_tasks SET last_succeeded = $1 WHERE name = $2", time.Now().UTC(), name)
  943. if err != nil {
  944. return errors.Wrap(err, "PollScheduledTask(): ")
  945. }
  946. return nil
  947. }
  948. func (model *DBModel) PutKeyPair(caName, name string, key, cert []byte) error {
  949. _, err := model.DB.Exec(`
  950. INSERT INTO pki_key_pairs (ca_name, name, key_bytes, cert_bytes)
  951. VALUES($1, $2, $3, $4)
  952. ON CONFLICT ON CONSTRAINT pk_pki_key_pairs
  953. DO
  954. UPDATE SET key_bytes = $3, cert_bytes = $4;
  955. `,
  956. caName, name, key, cert,
  957. )
  958. return errors.Wrap(err, "PutKeyPair(): ")
  959. }
  960. func (model *DBModel) GetServerKeyPair(caName, name string) ([]byte, []byte, error) {
  961. var key, cert []byte
  962. rows, err := model.DB.Query(`
  963. SELECT key_bytes, cert_bytes FROM pki_key_pairs
  964. WHERE ca_name = $1 AND name = $2
  965. `,
  966. caName, name,
  967. )
  968. if err != nil {
  969. return nil, nil, errors.Wrap(err, "GetServerKeyPair(): ")
  970. }
  971. defer rows.Close()
  972. for rows.Next() {
  973. err = rows.Scan(&key, &cert)
  974. if err != nil {
  975. return nil, nil, errors.Wrap(err, "GetServerKeyPair(): ")
  976. }
  977. return key, cert, nil
  978. }
  979. return nil, nil, nil
  980. }
  981. func (model *DBModel) rowToVPSInstance(row RowScanner) (*VPSInstance, error) {
  982. var serviceProvider string
  983. var providerInstanceId string
  984. var tenantId *int
  985. var ipv4 string
  986. var ipv6 string
  987. var bytesMonthly int64
  988. var created time.Time
  989. var deprecated bool
  990. var deleted bool
  991. err := row.Scan(&serviceProvider, &providerInstanceId, &tenantId, &ipv4, &ipv6, &bytesMonthly, &created, &deprecated, &deleted)
  992. if err != nil {
  993. return nil, errors.Wrap(err, "rowToVPSInstance(): ")
  994. }
  995. tenantIdInt := 0
  996. if tenantId != nil {
  997. tenantIdInt = *tenantId
  998. }
  999. return &VPSInstance{
  1000. ServiceProvider: serviceProvider,
  1001. ProviderInstanceId: providerInstanceId,
  1002. TenantId: tenantIdInt,
  1003. IPV4: ipv4,
  1004. IPV6: ipv6,
  1005. BytesMonthly: bytesMonthly,
  1006. Created: created,
  1007. Deprecated: deprecated,
  1008. Deleted: deleted,
  1009. }, nil
  1010. }
  1011. func domainNamesToString(slice []string) string {
  1012. sort.Strings(slice)
  1013. return strings.Join(slice, ",")
  1014. }