diff options
Diffstat (limited to 'src/main.rs')
-rw-r--r-- | src/main.rs | 308 |
1 files changed, 241 insertions, 67 deletions
diff --git a/src/main.rs b/src/main.rs index 932a29f..7edd019 100644 --- a/src/main.rs +++ b/src/main.rs @@ -21,7 +21,7 @@ mod database; mod datastructures; -use crate::datastructures::{Config, Cookie, FormData}; +use crate::datastructures::{Config, Cookie, FormData, TestSuite}; use anyhow::Result; use argon2::password_hash::PasswordHash; use clap::{App, Arg, ArgMatches, SubCommand}; @@ -51,12 +51,12 @@ impl<R: BufRead, W: Write> IOModule<R, W> { // Read stdin from upstream. let mut buffer = String::new(); self.reader.read_to_string(&mut buffer)?; + //log::debug!("{}", buffer); let data = datastructures::FormData::from(buffer); - let redis_conn = redis::Client::open("redis://127.0.0.1/")?; - let ret = verify_login(&cfg, &data, redis_conn.clone()).await; + let ret = verify_login(&cfg, &data).await; if let Err(ref e) = ret { eprintln!("{:?}", e); @@ -64,6 +64,7 @@ impl<R: BufRead, W: Write> IOModule<R, W> { } if ret.unwrap_or(false) { + let redis_conn = redis::Client::open("redis://127.0.0.1/")?; let cookie = Cookie::generate(data.get_user()); let mut conn = redis_conn.get_async_connection().await?; @@ -103,6 +104,7 @@ impl<R: BufRead, W: Write> IOModule<R, W> { // Processing the `authenticate-cookie` called by cgit. async fn cmd_authenticate_cookie(matches: &ArgMatches<'_>, cfg: Config) -> Result<bool> { let cookies = matches.value_of("http-cookie").unwrap_or(""); + let repo = matches.value_of("repo").unwrap_or(""); let mut bypass = false; @@ -110,7 +112,7 @@ async fn cmd_authenticate_cookie(matches: &ArgMatches<'_>, cfg: Config) -> Resul bypass = true; } - if bypass { + if bypass || !cfg.check_repo_protect(repo){ return Ok(true); } @@ -121,6 +123,26 @@ async fn cmd_authenticate_cookie(matches: &ArgMatches<'_>, cfg: Config) -> Resul let redis_conn = redis::Client::open("redis://127.0.0.1/")?; let mut conn = redis_conn.get_async_connection().await?; + if !repo.is_empty() { + let key = format!("cgit_repo_{}", repo); + if !conn.exists(&key).await? { + let mut sql_conn = SqliteConnectOptions::from_str(cfg.get_database_location())? + .read_only(true) + .connect() + .await?; + if let Some((users, )) = + sqlx::query_as::<_, (String, )>(r#"SELECT "users" FROM "repos" WHERE "repo" = ? "#) + .bind(repo) + .fetch_optional(&mut sql_conn) + .await? + { + let iter = users.split_whitespace().collect::<Vec<&str>>(); + conn.sadd(&key, iter).await?; + } + } + // TODO: redis repository ACL check should goes here + } + if let Ok(Some(cookie)) = Cookie::load_from_request(cookies) { if let Ok(r) = conn .get::<_, String>(format!("cgit_auth_{}", cookie.get_key())) @@ -167,7 +189,7 @@ async fn cmd_init(cfg: Config) -> Result<()> { Ok(()) } -async fn verify_login(cfg: &Config, data: &FormData, redis_conn: redis::Client) -> Result<bool> { +async fn verify_login(cfg: &Config, data: &FormData) -> Result<bool> { if !cfg.test { let last_copied = cfg.get_last_copy_timestamp().await.unwrap_or(0); if last_copied == 0 || cfg.get_last_commit_timestamp().await.unwrap_or(0) != last_copied { @@ -179,8 +201,6 @@ async fn verify_login(cfg: &Config, data: &FormData, redis_conn: redis::Client) } } - let mut rd = redis_conn.get_async_connection().await?; - let mut conn = sqlx::sqlite::SqliteConnectOptions::from_str( cfg.get_copied_database_location().to_str().unwrap(), )? @@ -189,26 +209,13 @@ async fn verify_login(cfg: &Config, data: &FormData, redis_conn: redis::Client) .connect() .await?; - let (passwd_hash, uid) = sqlx::query_as::<_, (String, String)>( - r#"SELECT "password", "uid" FROM "accounts" WHERE "user" = ?"#, + let (passwd_hash,) = sqlx::query_as::<_, (String,)>( + r#"SELECT "password" FROM "accounts" WHERE "user" = ?"#, ) .bind(data.get_user()) .fetch_one(&mut conn) .await?; - let key = format!("cgit_repo_{}", data.get_user()); - if !rd.exists(&key).await? { - if let Some((repos,)) = - sqlx::query_as::<_, (String,)>(r#"SELECT "repos" FROM "repo" WHERE "uid" = ? "#) - .bind(uid) - .fetch_optional(&mut conn) - .await? - { - let iter = repos.split_whitespace().collect::<Vec<&str>>(); - rd.sadd(&key, iter).await?; - } - } - let parsed_hash = PasswordHash::new(passwd_hash.as_str()).unwrap(); Ok(data.verify_password(&parsed_hash)) } @@ -360,15 +367,15 @@ async fn cmd_reset_database(matches: &ArgMatches<'_>, cfg: Config) -> Result<()> async fn cmd_upgrade_database(cfg: Config) -> Result<()> { let tmp_dir = TempDir::new("rolling")?; - let v1_path = tmp_dir.path().join("v1.db"); let v2_path = tmp_dir.path().join("v2.db"); + let v3_path = tmp_dir.path().join("v3.db"); - drop(std::fs::File::create(&v2_path).expect("Create v2 database failure")); + drop(std::fs::File::create(&v3_path).expect("Create v3 database failure")); - std::fs::copy(cfg.get_database_location(), &v1_path) - .expect("Copy v1 database to tempdir failure"); + std::fs::copy(cfg.get_database_location(), &v2_path) + .expect("Copy v2 database to tempdir failure"); - let mut origin_conn = SqliteConnectOptions::from_str(v1_path.as_path().to_str().unwrap())? + let mut origin_conn = SqliteConnectOptions::from_str(v2_path.as_path().to_str().unwrap())? .read_only(true) .connect() .await?; @@ -382,17 +389,16 @@ async fn cmd_upgrade_database(cfg: Config) -> Result<()> { #[allow(deprecated)] if v.eq(database::previous::VERSION) { - let mut conn = SqliteConnection::connect(v2_path.as_path().to_str().unwrap()).await?; + let mut conn = SqliteConnection::connect(v3_path.as_path().to_str().unwrap()).await?; sqlx::query(database::current::CREATE_TABLES) .execute(&mut conn) .await?; - let mut iter = sqlx::query_as::<_, (String, String)>(r#"SELECT * FROM "accounts""#) + let mut iter = sqlx::query_as::<_, (String, String, String)>(r#"SELECT * FROM "accounts""#) .fetch(&mut origin_conn); - while let Some(Ok((user, passwd))) = iter.next().await { - let uid = uuid::Uuid::new_v4().to_hyphenated().to_string(); + while let Some(Ok((user, passwd, uid))) = iter.next().await { sqlx::query(r#"INSERT INTO "accounts" VALUES (?, ?, ?)"#) .bind(user.as_str()) .bind(passwd) @@ -403,7 +409,7 @@ async fn cmd_upgrade_database(cfg: Config) -> Result<()> { } drop(conn); - std::fs::copy(&v2_path, cfg.get_database_location()) + std::fs::copy(&v3_path, cfg.get_database_location()) .expect("Copy back to database location failure"); println!("Upgrade database successful"); } else { @@ -420,6 +426,96 @@ async fn cmd_upgrade_database(cfg: Config) -> Result<()> { Ok(()) } +async fn cmd_repo_user_control(matches: &ArgMatches<'_>, cfg: Config, is_delete: bool) -> Result<()> { + let repo = matches.value_of("repo").unwrap_or(""); + let user = matches.value_of("user").unwrap_or(""); + + let clear_all = matches.is_present("clear-all"); + + if repo.is_empty() || (is_delete && !clear_all && user.is_empty()) || (!is_delete && user.is_empty()) { + return Err(anyhow::Error::msg("Invalid repository or username")); + } + + let redis_client = redis::Client::open("redis://127.0.0.1/")?; + let mut redis_conn = redis_client.get_async_connection().await?; + + let mut conn = SqliteConnection::connect(cfg.get_database_location()).await?; + + if sqlx::query(r#"SELECT "users" FROM "repos" WHERE "repo" = ?"#) + .bind(repo) + .fetch_optional(&mut conn) + .await? + .is_none() { + if is_delete { + println!("Row is empty."); + return Ok(()) + } + sqlx::query(r#"INSERT INTO "repos" VALUES (?, ?)"#) + .bind(repo) + .bind("") + .execute(&mut conn) + .await?; + } + + let (users,) = + sqlx::query_as::<_, (String,)>(r#"SELECT "users" FROM "repos" WHERE "repo" = ?"#) + .bind(repo) + .fetch_optional(&mut conn) + .await? + .unwrap(); + let mut users = users.split_whitespace().collect::<Vec<&str>>(); + + if let Some(index) = users.clone().into_iter().position(|x| x.eq(user)) { + if is_delete { + if clear_all { + users.clear(); + } else { + users.remove(index); + } + } else { + return Err(anyhow::Error::msg("User already in repository ACL")); + } + } + + if !is_delete { + users.push(user); + } + + sqlx::query(r#"UPDATE "repos" SET "users" = ? WHERE "repo" = ?"#) + .bind(users.join(" ")) + .bind(repo) + .execute(&mut conn) + .await?; + + let redis_key = format!("cgit_repo_{}", repo); + if redis_conn.exists::<_, i32>(&redis_key).await? == 0{ + redis_conn.sadd::<_, _, i32>(&redis_key, users).await?; + } else { + if is_delete { + if clear_all { + redis_conn.del::<_, i32>(&redis_key).await?; + } else { + redis_conn.srem::<_, _, i32>(&redis_key, user).await?; + } + } else { + redis_conn.sadd::<_, _, i32>(&redis_key, user).await?; + } + } + + if !clear_all { + println!("{} user {} {} repository {} ACL successful", + if is_delete { "Delete" } else { "Add" }, + user, + if is_delete { "from" } else { "to" }, + repo, + ); + } else { + println!("Clear all users from repository {} ACL", repo); + } + + Ok(()) +} + async fn async_main(arg_matches: ArgMatches<'_>) -> Result<i32> { let cfg = if std::env::args().any(|x| x.eq("--test")) { Config::generate_test_config() @@ -466,6 +562,13 @@ async fn async_main(arg_matches: ArgMatches<'_>) -> Result<i32> { ("upgrade", Some(_matches)) => { cmd_upgrade_database(cfg).await?; } + ("repoadd", Some(matches)) => { + cmd_repo_user_control(matches, cfg, false).await? + } + ("repodel", Some(matches)) => { + cmd_repo_user_control(matches, cfg, true).await?; + } + // TODO: other repository rated command _ => {} } Ok(0) @@ -524,7 +627,29 @@ fn get_arg_matches(arguments: Option<Vec<&str>>) -> ArgMatches { ) .subcommand( SubCommand::with_name("upgrade") - .about("Upgrade database from v1(v0.1.x - v0.2.x) to v2(^v0.3.x)"), + .about("Upgrade database from v2(v0.3.x) to v3(^v0.4.x)"), + ) + .subcommand( + SubCommand::with_name("repoadd") + .about("Add user to repository") + .arg(Arg::with_name("repo").required(true)) + .arg(Arg::with_name("user").required(true)), + ) + .subcommand( + SubCommand::with_name("repodel") + .about("Del user from repository") + .arg(Arg::with_name("repo").required(true)) + .arg(Arg::with_name("user").takes_value(true)) + .arg( + Arg::with_name("clear-all") + .long("--clear-all") + .conflicts_with("user"), + ), + ) + .subcommand( + SubCommand::with_name("repos") + .about("Show all repositories or only show specify repository detail") + .arg(Arg::with_name("repo").takes_value(true)), ); let matches = if let Some(args) = arguments { @@ -588,8 +713,8 @@ fn main() -> Result<()> { #[cfg(test)] mod test { - use crate::datastructures::{rand_str, Config}; - use crate::{cmd_add_user, cmd_init, cmd_authenticate_cookie}; + use crate::datastructures::{rand_str, Config, TestSuite, ProtectedRepo}; + use crate::{cmd_add_user, cmd_authenticate_cookie, cmd_init}; use crate::{get_arg_matches, IOModule}; use argon2::{ password_hash::{PasswordHash, PasswordHasher, PasswordVerifier, SaltString}, @@ -687,14 +812,14 @@ mod test { } #[test] - fn test_1_auth_failure() { + fn test_01_auth_failure() { let out = test_auth_post(); assert!(out.starts_with("Status: 403")); assert!(out.ends_with("\n\n")); } #[test] - fn test_0_init_database() { + fn test_00_init_database() { let tmp_dir = Path::new("test"); if tmp_dir.exists() { @@ -724,7 +849,7 @@ mod test { } #[test] - fn test_2_insert_user() { + fn test_02_insert_user() { lock(&PathBuf::from("test/DATABASE_INITED"), 3); let matches = crate::get_arg_matches(Some(vec!["a", "adduser", "hunter2", "hunter2"])); match matches.subcommand() { @@ -743,7 +868,13 @@ mod test { } #[test] - fn test_3_auth_pass() { + fn test_03_insert_repo() { + lock(&PathBuf::from("test/USER_WRITTEN"), 5); + let matches = crate::get_arg_matches(Some(vec!["a", "repoadd", "hunter2", "hunter2"])); + } + + #[test] + fn test_91_auth_pass() { lock(&PathBuf::from("test/USER_WRITTEN"), 7); let s = test_auth_post(); @@ -761,7 +892,7 @@ mod test { } #[test] - fn test_4_authenticate_cookie() { + fn test_92_authenticate_cookie() { lock(&PathBuf::from("test/RESPONSE"), 15); let mut buffer = String::new(); @@ -774,42 +905,85 @@ mod test { for line in buffer.lines().map(|x| x.trim()) { if !line.starts_with("Set-Cookie") { - continue + continue; } let (_, value) = line.split_once(":").unwrap(); let (value, _) = value.split_once(";").unwrap(); cookie = value.trim(); - break + break; } - let matches = get_arg_matches( - Some( - vec![ - "a", - "authenticate-cookie", - cookie, - "GET", - "", - "https://git.example.com/", - "/", - "git.example.com", - "on", - "", - "", - "/", - "/?p=login", + let matches = get_arg_matches(Some(vec![ + "a", + "authenticate-cookie", + cookie, + "GET", + "", + "https://git.example.com/", + "/", + "git.example.com", + "on", + "", + "", + "/", + "/?p=login", ])); let result = match matches.subcommand() { - ("authenticate-cookie", Some(matches)) => { - tokio::runtime::Builder::new_current_thread() - .enable_all() - .build() - .unwrap() - .block_on(cmd_authenticate_cookie(matches, Config::generate_test_config())) - .unwrap() - } - _ => unreachable!() + ("authenticate-cookie", Some(matches)) => tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap() + .block_on(cmd_authenticate_cookie( + matches, + Config::generate_test_config(), + )) + .unwrap(), + _ => unreachable!(), }; assert!(result); } + + fn write_to_specify_file(path: &PathBuf, data: &[u8]) -> Result<(), std::io::Error> { + let mut file = std::fs::OpenOptions::new() + .create(true) + .write(true) + .truncate(true) + .open(path)?; + file.write_all(data)?; + file.sync_all()?; + Ok(()) + } + + #[test] + fn test_02_protected_repo_parser() { + let tmpdir = tempdir::TempDir::new("test").unwrap(); + + let another_file_path = format!("include={}/REPO_SETTING # TEST", tmpdir.path().to_str().unwrap()); + write_to_specify_file(&tmpdir.path().join("CFG"), another_file_path.as_bytes()).unwrap(); + + + write_to_specify_file(&tmpdir.path().join("REPO_SETTING"), b"repo.url=test\nrepo.protect=true").unwrap(); + + + let result = ProtectedRepo::load_from_file(tmpdir.path().join("CFG")); + + assert!(result.query_is_protected("test")); + assert!(!result.query_is_all_protected()); + + write_to_specify_file(&tmpdir.path().join("REPO_SETTING"), b"repo.protect=true\nrepo.url=test").unwrap(); + + let result = ProtectedRepo::load_from_file(tmpdir.path().join("CFG")); + + assert!(!result.query_is_protected("test")); + assert!(!result.query_is_all_protected()); + + write_to_specify_file(&tmpdir.path().join("REPO_SETTING"), b"repo.protect=true\nrepo.url=test\n\ncgit-simple-auth-protect=full").unwrap(); + + let result = ProtectedRepo::load_from_file(tmpdir.path().join("CFG")); + + assert!(result.query_is_all_protected()); + assert!(result.query_is_protected("test")); + + tmpdir.close().unwrap(); + } } |