aboutsummaryrefslogtreecommitdiff
path: root/src/main.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/main.rs')
-rw-r--r--src/main.rs308
1 files changed, 241 insertions, 67 deletions
diff --git a/src/main.rs b/src/main.rs
index 932a29f..7edd019 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -21,7 +21,7 @@
mod database;
mod datastructures;
-use crate::datastructures::{Config, Cookie, FormData};
+use crate::datastructures::{Config, Cookie, FormData, TestSuite};
use anyhow::Result;
use argon2::password_hash::PasswordHash;
use clap::{App, Arg, ArgMatches, SubCommand};
@@ -51,12 +51,12 @@ impl<R: BufRead, W: Write> IOModule<R, W> {
// Read stdin from upstream.
let mut buffer = String::new();
self.reader.read_to_string(&mut buffer)?;
+
//log::debug!("{}", buffer);
let data = datastructures::FormData::from(buffer);
- let redis_conn = redis::Client::open("redis://127.0.0.1/")?;
- let ret = verify_login(&cfg, &data, redis_conn.clone()).await;
+ let ret = verify_login(&cfg, &data).await;
if let Err(ref e) = ret {
eprintln!("{:?}", e);
@@ -64,6 +64,7 @@ impl<R: BufRead, W: Write> IOModule<R, W> {
}
if ret.unwrap_or(false) {
+ let redis_conn = redis::Client::open("redis://127.0.0.1/")?;
let cookie = Cookie::generate(data.get_user());
let mut conn = redis_conn.get_async_connection().await?;
@@ -103,6 +104,7 @@ impl<R: BufRead, W: Write> IOModule<R, W> {
// Processing the `authenticate-cookie` called by cgit.
async fn cmd_authenticate_cookie(matches: &ArgMatches<'_>, cfg: Config) -> Result<bool> {
let cookies = matches.value_of("http-cookie").unwrap_or("");
+ let repo = matches.value_of("repo").unwrap_or("");
let mut bypass = false;
@@ -110,7 +112,7 @@ async fn cmd_authenticate_cookie(matches: &ArgMatches<'_>, cfg: Config) -> Resul
bypass = true;
}
- if bypass {
+ if bypass || !cfg.check_repo_protect(repo){
return Ok(true);
}
@@ -121,6 +123,26 @@ async fn cmd_authenticate_cookie(matches: &ArgMatches<'_>, cfg: Config) -> Resul
let redis_conn = redis::Client::open("redis://127.0.0.1/")?;
let mut conn = redis_conn.get_async_connection().await?;
+ if !repo.is_empty() {
+ let key = format!("cgit_repo_{}", repo);
+ if !conn.exists(&key).await? {
+ let mut sql_conn = SqliteConnectOptions::from_str(cfg.get_database_location())?
+ .read_only(true)
+ .connect()
+ .await?;
+ if let Some((users, )) =
+ sqlx::query_as::<_, (String, )>(r#"SELECT "users" FROM "repos" WHERE "repo" = ? "#)
+ .bind(repo)
+ .fetch_optional(&mut sql_conn)
+ .await?
+ {
+ let iter = users.split_whitespace().collect::<Vec<&str>>();
+ conn.sadd(&key, iter).await?;
+ }
+ }
+ // TODO: redis repository ACL check should goes here
+ }
+
if let Ok(Some(cookie)) = Cookie::load_from_request(cookies) {
if let Ok(r) = conn
.get::<_, String>(format!("cgit_auth_{}", cookie.get_key()))
@@ -167,7 +189,7 @@ async fn cmd_init(cfg: Config) -> Result<()> {
Ok(())
}
-async fn verify_login(cfg: &Config, data: &FormData, redis_conn: redis::Client) -> Result<bool> {
+async fn verify_login(cfg: &Config, data: &FormData) -> Result<bool> {
if !cfg.test {
let last_copied = cfg.get_last_copy_timestamp().await.unwrap_or(0);
if last_copied == 0 || cfg.get_last_commit_timestamp().await.unwrap_or(0) != last_copied {
@@ -179,8 +201,6 @@ async fn verify_login(cfg: &Config, data: &FormData, redis_conn: redis::Client)
}
}
- let mut rd = redis_conn.get_async_connection().await?;
-
let mut conn = sqlx::sqlite::SqliteConnectOptions::from_str(
cfg.get_copied_database_location().to_str().unwrap(),
)?
@@ -189,26 +209,13 @@ async fn verify_login(cfg: &Config, data: &FormData, redis_conn: redis::Client)
.connect()
.await?;
- let (passwd_hash, uid) = sqlx::query_as::<_, (String, String)>(
- r#"SELECT "password", "uid" FROM "accounts" WHERE "user" = ?"#,
+ let (passwd_hash,) = sqlx::query_as::<_, (String,)>(
+ r#"SELECT "password" FROM "accounts" WHERE "user" = ?"#,
)
.bind(data.get_user())
.fetch_one(&mut conn)
.await?;
- let key = format!("cgit_repo_{}", data.get_user());
- if !rd.exists(&key).await? {
- if let Some((repos,)) =
- sqlx::query_as::<_, (String,)>(r#"SELECT "repos" FROM "repo" WHERE "uid" = ? "#)
- .bind(uid)
- .fetch_optional(&mut conn)
- .await?
- {
- let iter = repos.split_whitespace().collect::<Vec<&str>>();
- rd.sadd(&key, iter).await?;
- }
- }
-
let parsed_hash = PasswordHash::new(passwd_hash.as_str()).unwrap();
Ok(data.verify_password(&parsed_hash))
}
@@ -360,15 +367,15 @@ async fn cmd_reset_database(matches: &ArgMatches<'_>, cfg: Config) -> Result<()>
async fn cmd_upgrade_database(cfg: Config) -> Result<()> {
let tmp_dir = TempDir::new("rolling")?;
- let v1_path = tmp_dir.path().join("v1.db");
let v2_path = tmp_dir.path().join("v2.db");
+ let v3_path = tmp_dir.path().join("v3.db");
- drop(std::fs::File::create(&v2_path).expect("Create v2 database failure"));
+ drop(std::fs::File::create(&v3_path).expect("Create v3 database failure"));
- std::fs::copy(cfg.get_database_location(), &v1_path)
- .expect("Copy v1 database to tempdir failure");
+ std::fs::copy(cfg.get_database_location(), &v2_path)
+ .expect("Copy v2 database to tempdir failure");
- let mut origin_conn = SqliteConnectOptions::from_str(v1_path.as_path().to_str().unwrap())?
+ let mut origin_conn = SqliteConnectOptions::from_str(v2_path.as_path().to_str().unwrap())?
.read_only(true)
.connect()
.await?;
@@ -382,17 +389,16 @@ async fn cmd_upgrade_database(cfg: Config) -> Result<()> {
#[allow(deprecated)]
if v.eq(database::previous::VERSION) {
- let mut conn = SqliteConnection::connect(v2_path.as_path().to_str().unwrap()).await?;
+ let mut conn = SqliteConnection::connect(v3_path.as_path().to_str().unwrap()).await?;
sqlx::query(database::current::CREATE_TABLES)
.execute(&mut conn)
.await?;
- let mut iter = sqlx::query_as::<_, (String, String)>(r#"SELECT * FROM "accounts""#)
+ let mut iter = sqlx::query_as::<_, (String, String, String)>(r#"SELECT * FROM "accounts""#)
.fetch(&mut origin_conn);
- while let Some(Ok((user, passwd))) = iter.next().await {
- let uid = uuid::Uuid::new_v4().to_hyphenated().to_string();
+ while let Some(Ok((user, passwd, uid))) = iter.next().await {
sqlx::query(r#"INSERT INTO "accounts" VALUES (?, ?, ?)"#)
.bind(user.as_str())
.bind(passwd)
@@ -403,7 +409,7 @@ async fn cmd_upgrade_database(cfg: Config) -> Result<()> {
}
drop(conn);
- std::fs::copy(&v2_path, cfg.get_database_location())
+ std::fs::copy(&v3_path, cfg.get_database_location())
.expect("Copy back to database location failure");
println!("Upgrade database successful");
} else {
@@ -420,6 +426,96 @@ async fn cmd_upgrade_database(cfg: Config) -> Result<()> {
Ok(())
}
+async fn cmd_repo_user_control(matches: &ArgMatches<'_>, cfg: Config, is_delete: bool) -> Result<()> {
+ let repo = matches.value_of("repo").unwrap_or("");
+ let user = matches.value_of("user").unwrap_or("");
+
+ let clear_all = matches.is_present("clear-all");
+
+ if repo.is_empty() || (is_delete && !clear_all && user.is_empty()) || (!is_delete && user.is_empty()) {
+ return Err(anyhow::Error::msg("Invalid repository or username"));
+ }
+
+ let redis_client = redis::Client::open("redis://127.0.0.1/")?;
+ let mut redis_conn = redis_client.get_async_connection().await?;
+
+ let mut conn = SqliteConnection::connect(cfg.get_database_location()).await?;
+
+ if sqlx::query(r#"SELECT "users" FROM "repos" WHERE "repo" = ?"#)
+ .bind(repo)
+ .fetch_optional(&mut conn)
+ .await?
+ .is_none() {
+ if is_delete {
+ println!("Row is empty.");
+ return Ok(())
+ }
+ sqlx::query(r#"INSERT INTO "repos" VALUES (?, ?)"#)
+ .bind(repo)
+ .bind("")
+ .execute(&mut conn)
+ .await?;
+ }
+
+ let (users,) =
+ sqlx::query_as::<_, (String,)>(r#"SELECT "users" FROM "repos" WHERE "repo" = ?"#)
+ .bind(repo)
+ .fetch_optional(&mut conn)
+ .await?
+ .unwrap();
+ let mut users = users.split_whitespace().collect::<Vec<&str>>();
+
+ if let Some(index) = users.clone().into_iter().position(|x| x.eq(user)) {
+ if is_delete {
+ if clear_all {
+ users.clear();
+ } else {
+ users.remove(index);
+ }
+ } else {
+ return Err(anyhow::Error::msg("User already in repository ACL"));
+ }
+ }
+
+ if !is_delete {
+ users.push(user);
+ }
+
+ sqlx::query(r#"UPDATE "repos" SET "users" = ? WHERE "repo" = ?"#)
+ .bind(users.join(" "))
+ .bind(repo)
+ .execute(&mut conn)
+ .await?;
+
+ let redis_key = format!("cgit_repo_{}", repo);
+ if redis_conn.exists::<_, i32>(&redis_key).await? == 0{
+ redis_conn.sadd::<_, _, i32>(&redis_key, users).await?;
+ } else {
+ if is_delete {
+ if clear_all {
+ redis_conn.del::<_, i32>(&redis_key).await?;
+ } else {
+ redis_conn.srem::<_, _, i32>(&redis_key, user).await?;
+ }
+ } else {
+ redis_conn.sadd::<_, _, i32>(&redis_key, user).await?;
+ }
+ }
+
+ if !clear_all {
+ println!("{} user {} {} repository {} ACL successful",
+ if is_delete { "Delete" } else { "Add" },
+ user,
+ if is_delete { "from" } else { "to" },
+ repo,
+ );
+ } else {
+ println!("Clear all users from repository {} ACL", repo);
+ }
+
+ Ok(())
+}
+
async fn async_main(arg_matches: ArgMatches<'_>) -> Result<i32> {
let cfg = if std::env::args().any(|x| x.eq("--test")) {
Config::generate_test_config()
@@ -466,6 +562,13 @@ async fn async_main(arg_matches: ArgMatches<'_>) -> Result<i32> {
("upgrade", Some(_matches)) => {
cmd_upgrade_database(cfg).await?;
}
+ ("repoadd", Some(matches)) => {
+ cmd_repo_user_control(matches, cfg, false).await?
+ }
+ ("repodel", Some(matches)) => {
+ cmd_repo_user_control(matches, cfg, true).await?;
+ }
+ // TODO: other repository rated command
_ => {}
}
Ok(0)
@@ -524,7 +627,29 @@ fn get_arg_matches(arguments: Option<Vec<&str>>) -> ArgMatches {
)
.subcommand(
SubCommand::with_name("upgrade")
- .about("Upgrade database from v1(v0.1.x - v0.2.x) to v2(^v0.3.x)"),
+ .about("Upgrade database from v2(v0.3.x) to v3(^v0.4.x)"),
+ )
+ .subcommand(
+ SubCommand::with_name("repoadd")
+ .about("Add user to repository")
+ .arg(Arg::with_name("repo").required(true))
+ .arg(Arg::with_name("user").required(true)),
+ )
+ .subcommand(
+ SubCommand::with_name("repodel")
+ .about("Del user from repository")
+ .arg(Arg::with_name("repo").required(true))
+ .arg(Arg::with_name("user").takes_value(true))
+ .arg(
+ Arg::with_name("clear-all")
+ .long("--clear-all")
+ .conflicts_with("user"),
+ ),
+ )
+ .subcommand(
+ SubCommand::with_name("repos")
+ .about("Show all repositories or only show specify repository detail")
+ .arg(Arg::with_name("repo").takes_value(true)),
);
let matches = if let Some(args) = arguments {
@@ -588,8 +713,8 @@ fn main() -> Result<()> {
#[cfg(test)]
mod test {
- use crate::datastructures::{rand_str, Config};
- use crate::{cmd_add_user, cmd_init, cmd_authenticate_cookie};
+ use crate::datastructures::{rand_str, Config, TestSuite, ProtectedRepo};
+ use crate::{cmd_add_user, cmd_authenticate_cookie, cmd_init};
use crate::{get_arg_matches, IOModule};
use argon2::{
password_hash::{PasswordHash, PasswordHasher, PasswordVerifier, SaltString},
@@ -687,14 +812,14 @@ mod test {
}
#[test]
- fn test_1_auth_failure() {
+ fn test_01_auth_failure() {
let out = test_auth_post();
assert!(out.starts_with("Status: 403"));
assert!(out.ends_with("\n\n"));
}
#[test]
- fn test_0_init_database() {
+ fn test_00_init_database() {
let tmp_dir = Path::new("test");
if tmp_dir.exists() {
@@ -724,7 +849,7 @@ mod test {
}
#[test]
- fn test_2_insert_user() {
+ fn test_02_insert_user() {
lock(&PathBuf::from("test/DATABASE_INITED"), 3);
let matches = crate::get_arg_matches(Some(vec!["a", "adduser", "hunter2", "hunter2"]));
match matches.subcommand() {
@@ -743,7 +868,13 @@ mod test {
}
#[test]
- fn test_3_auth_pass() {
+ fn test_03_insert_repo() {
+ lock(&PathBuf::from("test/USER_WRITTEN"), 5);
+ let matches = crate::get_arg_matches(Some(vec!["a", "repoadd", "hunter2", "hunter2"]));
+ }
+
+ #[test]
+ fn test_91_auth_pass() {
lock(&PathBuf::from("test/USER_WRITTEN"), 7);
let s = test_auth_post();
@@ -761,7 +892,7 @@ mod test {
}
#[test]
- fn test_4_authenticate_cookie() {
+ fn test_92_authenticate_cookie() {
lock(&PathBuf::from("test/RESPONSE"), 15);
let mut buffer = String::new();
@@ -774,42 +905,85 @@ mod test {
for line in buffer.lines().map(|x| x.trim()) {
if !line.starts_with("Set-Cookie") {
- continue
+ continue;
}
let (_, value) = line.split_once(":").unwrap();
let (value, _) = value.split_once(";").unwrap();
cookie = value.trim();
- break
+ break;
}
- let matches = get_arg_matches(
- Some(
- vec![
- "a",
- "authenticate-cookie",
- cookie,
- "GET",
- "",
- "https://git.example.com/",
- "/",
- "git.example.com",
- "on",
- "",
- "",
- "/",
- "/?p=login",
+ let matches = get_arg_matches(Some(vec![
+ "a",
+ "authenticate-cookie",
+ cookie,
+ "GET",
+ "",
+ "https://git.example.com/",
+ "/",
+ "git.example.com",
+ "on",
+ "",
+ "",
+ "/",
+ "/?p=login",
]));
let result = match matches.subcommand() {
- ("authenticate-cookie", Some(matches)) => {
- tokio::runtime::Builder::new_current_thread()
- .enable_all()
- .build()
- .unwrap()
- .block_on(cmd_authenticate_cookie(matches, Config::generate_test_config()))
- .unwrap()
- }
- _ => unreachable!()
+ ("authenticate-cookie", Some(matches)) => tokio::runtime::Builder::new_current_thread()
+ .enable_all()
+ .build()
+ .unwrap()
+ .block_on(cmd_authenticate_cookie(
+ matches,
+ Config::generate_test_config(),
+ ))
+ .unwrap(),
+ _ => unreachable!(),
};
assert!(result);
}
+
+ fn write_to_specify_file(path: &PathBuf, data: &[u8]) -> Result<(), std::io::Error> {
+ let mut file = std::fs::OpenOptions::new()
+ .create(true)
+ .write(true)
+ .truncate(true)
+ .open(path)?;
+ file.write_all(data)?;
+ file.sync_all()?;
+ Ok(())
+ }
+
+ #[test]
+ fn test_02_protected_repo_parser() {
+ let tmpdir = tempdir::TempDir::new("test").unwrap();
+
+ let another_file_path = format!("include={}/REPO_SETTING # TEST", tmpdir.path().to_str().unwrap());
+ write_to_specify_file(&tmpdir.path().join("CFG"), another_file_path.as_bytes()).unwrap();
+
+
+ write_to_specify_file(&tmpdir.path().join("REPO_SETTING"), b"repo.url=test\nrepo.protect=true").unwrap();
+
+
+ let result = ProtectedRepo::load_from_file(tmpdir.path().join("CFG"));
+
+ assert!(result.query_is_protected("test"));
+ assert!(!result.query_is_all_protected());
+
+ write_to_specify_file(&tmpdir.path().join("REPO_SETTING"), b"repo.protect=true\nrepo.url=test").unwrap();
+
+ let result = ProtectedRepo::load_from_file(tmpdir.path().join("CFG"));
+
+ assert!(!result.query_is_protected("test"));
+ assert!(!result.query_is_all_protected());
+
+ write_to_specify_file(&tmpdir.path().join("REPO_SETTING"), b"repo.protect=true\nrepo.url=test\n\ncgit-simple-auth-protect=full").unwrap();
+
+ let result = ProtectedRepo::load_from_file(tmpdir.path().join("CFG"));
+
+ assert!(result.query_is_all_protected());
+ assert!(result.query_is_protected("test"));
+
+ tmpdir.close().unwrap();
+ }
}