(* 
    Transactions for OCaml XenStore Daemon.
    Copyright (C) 2008 Patrick Colp University of British Columbia

    This program is free software; you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation; either version 2 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program; if not, write to the Free Software
    Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
*)

(** Identity of a transaction: the pair (domain id, transaction id).
    Two values denote the same transaction when both fields match
    (see [equal]). *)
type tr =
  {
    domain_id : int; (** id of the domain that owns the transaction *)
    transaction_id : int32 (** identifier handed out by [new_transaction] for that domain *)
  }

(** Kind of access a transaction recorded against a path. *)
type operation =
  | NONE (** recorded by [remove_node] on the parent of a removed node; [fire_watch] ignores it *)
  | READ (** the path was read through to the base store *)
  | WRITE (** the path was created or written inside the transaction *)
  | RM (** the path was removed inside the transaction *)

(** One recorded operation of a transaction on a path. *)
type element =
  {
    transaction : tr; (** transaction that performed the operation *)
    operation : operation; (** what was done *)
    path : string; (** path the operation applies to *)
    mutable modified : bool (** starts [false]; set by [invalidate], consulted by [validate] *)
  }

(** Per-domain counter kept for each transaction and adjusted by
    [domain_entry_incr] / [domain_entry_decr]. *)
type changed_domain =
  {
    id : int; (** domain id *)
    entries : int (** running delta; may be negative. Presumably the change in the domain's entry count -- TODO confirm with callers *)
  }

(** [equal t1 t2] is [true] when both transactions carry the same domain id
    and the same transaction id. *)
let equal t1 t2 =
  match t1, t2 with
  | { domain_id = left_domain; transaction_id = left_id },
    { domain_id = right_domain; transaction_id = right_id } ->
      left_domain = right_domain && left_id = right_id

(** [fire_watch watches changed_node] notifies [watches] about one changed
    node by calling [watches#fire_watches] on the node's path.  The second
    argument is always [false]; the third is [true] for [RM] and [false] for
    [WRITE].  [NONE] and [READ] elements trigger nothing. *)
let fire_watch watches changed_node =
  let notify third_flag = watches#fire_watches changed_node.path false third_flag in
  match changed_node.operation with
  | RM -> notify true
  | WRITE -> notify false
  | NONE | READ -> ()

(** [fire_watches watches changed_nodes] applies [fire_watch] to every
    element of [changed_nodes], in list order. *)
let fire_watches watches changed_nodes =
  List.iter (fun changed_node -> fire_watch watches changed_node) changed_nodes

(** [make domain_id transaction_id] builds the transaction identity for the
    given domain and transaction id. *)
let make domain_id transaction_id = { domain_id; transaction_id }

(** [make_element transaction operation path] records that [transaction]
    performed [operation] on [path].  The new element starts out with
    [modified = false]. *)
let make_element transaction operation path =
  { transaction; operation; path; modified = false }

(** Hash tables keyed by transaction identity.  Keys are compared with
    [equal] and hashed with the generic [Hashtbl.hash]. *)
module Transaction_hashtbl = Hashtbl.Make (struct
  type t = tr

  (* Non-recursive binding: the right-hand side is the top-level [equal]. *)
  let equal left right = equal left right

  let hash key = Hashtbl.hash key
end)

(** Registry of the reads performed by open transactions.

    The same [element] values are indexed twice: by path ([m_paths]) and by
    transaction ([m_transactions]).  Because both indexes share the elements,
    a [modified] flag set through one index is visible through the other. *)
class transaction_reads =
object (self)
  val m_paths = Hashtbl.create 32
  val m_transactions = Transaction_hashtbl.create 8
  method private paths = m_paths
  method private transactions = m_transactions

  (** [add transaction path] records that [transaction] read [path].
      Nothing is added to an index that already holds an entry for this
      (transaction, path) pair. *)
  method add transaction path =
    let operation = make_element transaction READ path in
    (* A missing key is treated as an empty list: one lookup instead of
       [mem] followed by [find]. *)
    let for_path =
      try Hashtbl.find self#paths path with Not_found -> [] in
    let for_transaction =
      try Transaction_hashtbl.find self#transactions transaction with Not_found -> [] in
    if not (List.exists (fun op -> equal transaction op.transaction) for_path)
    then Hashtbl.replace self#paths path (operation :: for_path);
    if not (List.exists (fun op -> path = op.path) for_transaction)
    then Transaction_hashtbl.replace self#transactions transaction (operation :: for_transaction)

  (** [path_operations path] returns the reads recorded on [path].
      @raise Not_found if no read is recorded on [path]. *)
  method path_operations path = Hashtbl.find self#paths path

  (** [remove_path_operation operation] drops, from the per-path index, every
      read on [operation.path] made by [operation.transaction].  The path
      entry itself is deleted once no read is left on it.
      @raise Not_found if no read is recorded on [operation.path]. *)
  method remove_path_operation operation =
    let remaining =
      List.filter
        (fun op -> not (equal op.transaction operation.transaction))
        (self#path_operations operation.path) in
    match remaining with
    | [] -> Hashtbl.remove self#paths operation.path
    | _ :: _ -> Hashtbl.replace self#paths operation.path remaining

  (** [remove_transaction_operations transaction] forgets every read made by
      [transaction] in both indexes.  Unknown transactions are ignored. *)
  method remove_transaction_operations transaction =
    let operations =
      try self#transaction_operations transaction with Not_found -> [] in
    (* Handle [Not_found] per element: a path that is already gone must not
       abort the cleanup of the remaining paths. *)
    List.iter
      (fun operation ->
        try self#remove_path_operation operation with Not_found -> ())
      operations;
    Transaction_hashtbl.remove self#transactions transaction

  (** [transaction_operations transaction] returns the reads recorded for
      [transaction].
      @raise Not_found if [transaction] has no recorded read. *)
  method transaction_operations transaction = Transaction_hashtbl.find self#transactions transaction
end

(** Transaction-private view layered over a base store.

    The object inherits from [Store.store]; that inherited state (reached
    through [super]) is the transaction's private copy.  [store] is the base
    store the transaction was started against and [reads] is the shared read
    registry.  Every path touched locally gets an entry in [m_updates]
    (operation [WRITE], [RM] or [NONE]); paths with such an entry are served
    from the private copy, all other paths fall through to the base store. *)
class ['contents] transaction_store (transaction : tr) (store : 'contents Store.store) (reads : transaction_reads) =
object (self)
  inherit ['contents]Store.store as super
  val m_reads = reads
  val m_store = store
  val m_transaction = transaction
  val m_updates = Hashtbl.create 8
  (* Domain id of the transaction this view belongs to. *)
  method private domain_id = self#transaction.domain_id
  (* Pushes local changes into the base store.  A node whose path has a
     WRITE/RM/NONE entry is handed to the base store's [replace_node];
     otherwise the walk continues into the node's children (for [Children]
     and [Hack] contents) and stops at any other kind of contents. *)
  method private merge_node node =
    if self#op_exists node#path WRITE || self#op_exists node#path RM || self#op_exists node#path NONE
    then self#store#replace_node node
    else
      match node#contents with
      | Store.Children children | Store.Hack (_, children) -> List.iter (fun child -> self#merge_node child) children
      | _ -> ()
  (* Records operation [op] on [path].
     - WRITE is stored unless the path is currently marked RM.
     - RM and NONE overwrite whatever entry the path had.
     - READ is registered in the shared read registry, once per path for
       this transaction. *)
  method private op_add path op =
    match op with
    | WRITE -> if not (self#op_exists path RM) then Hashtbl.replace self#updates path (make_element self#transaction op path)
    | RM -> Hashtbl.replace self#updates path (make_element self#transaction op path)
    | READ -> if not (self#op_exists path READ) then self#reads#add self#transaction path
    | NONE -> Hashtbl.replace self#updates path (make_element self#transaction op path)
  (* For WRITE/RM/NONE: true iff the update entry of [path] carries exactly
     [op].  For READ: true iff the read registry holds a read of [path] by
     this transaction.  A missing entry yields [false]. *)
  method private op_exists path op =
    match op with
    | WRITE | RM | NONE -> (try (Hashtbl.find self#updates path).operation = op with Not_found -> false)
    | READ -> (try List.exists (fun op -> op.transaction = self#transaction) (self#reads#path_operations path) with Not_found -> false)
  method private reads = m_reads
  method private store = m_store
  method private transaction = m_transaction
  method private updates = m_updates
  (* All update entries (WRITE, RM and NONE) recorded so far, in
     [Hashtbl.fold] order. *)
  method changed_nodes = Hashtbl.fold (fun path element nodes -> element :: nodes) self#updates []
  (* Records a WRITE on [path] (unless one is already recorded) and creates
     the node in the private copy. *)
  method create_node path =
    if not (self#op_exists path WRITE) then self#op_add path WRITE;
    super#create_node path
  (* Merges the private copy into the base store, starting from the root. *)
  method merge = self#merge_node self#root
  (* Paths with a local update entry are answered from the private copy,
     all others from the base store. *)
  method node_exists path =
    if self#op_exists path WRITE || self#op_exists path RM || self#op_exists path NONE then super#node_exists path else self#store#node_exists path
  (* Same routing as [node_exists]; a read that falls through to the base
     store is first recorded as a READ. *)
  method read_node path =
    if self#op_exists path WRITE || self#op_exists path RM || self#op_exists path NONE
    then super#read_node path
    else (
      self#op_add path READ;
      self#store#read_node path
    )
  (* Removes [path] from the transaction's view and records an RM on it.

     If the parent already has a local update entry, the removal is applied
     directly to the private copy.  Otherwise the parent is first brought
     into the private copy: it is created there if missing (taking the base
     store's contents with the child list emptied), any base-store children
     it lacks are added, the child is removed, and the parent is recorded
     as NONE.

     Raises [Constants.Xs_error] with [EINVAL] when the parent's contents
     (in the base store or in the private copy) are of a kind that cannot
     hold children. *)
  method remove_node path =
    let parent_path = Store.parent_path path in
    if self#op_exists parent_path WRITE || self#op_exists parent_path RM || self#op_exists parent_path NONE
    then (
      super#remove_node path;
      self#op_add path RM
    )
    else (
      if not (super#node_exists parent_path)
      then (
        super#create_node parent_path;
        (* Copy the base store's contents of the parent, minus its children. *)
        let contents =
          (match (self#store#read_node parent_path) with
            | Store.Children _ -> Store.Children []
            | Store.Hack (value, _) -> Store.Hack (value, [])
            | contents -> contents) in
        (super#get_node parent_path)#set_contents contents
      );
      let self_parent_node = self#get_node parent_path in
      match self_parent_node#contents with
      | Store.Children self_parent_children | Store.Hack (_, self_parent_children) -> (
            (* Add every base-store child that the private parent does not
               have yet (children are matched with [Store.compare]). *)
            (match self#store#read_node parent_path with
              | Store.Children store_parent_children | Store.Hack (_, store_parent_children) -> List.iter (fun store_parent_child -> if not (List.exists (fun self_parent_child -> Store.compare self_parent_child store_parent_child = 0) self_parent_children) then ignore (self_parent_node#add_child store_parent_child)) store_parent_children
              | Store.Empty -> ()
              | Store.Value _ -> raise (Constants.Xs_error (Constants.EINVAL, "Transaction.transaction_store#remove_node", path)));
            self_parent_node#remove_child path;
            self#op_add path RM;
            self#op_add parent_path NONE
          )
      | _ -> raise (Constants.Xs_error (Constants.EINVAL, "Transaction.transaction_store#remove_node", path))
    )
  (* Writes [contents] to [path] in the private copy.
     - Path already has a local update entry: the node is created in the
       private copy if missing, a WRITE is requested through [op_add], and
       the contents are written.
     - Otherwise, if the node exists in the base store: it is created
       locally through [create_node] (which records the WRITE) and then
       written.
     - Otherwise raises [Constants.Xs_error] with [EINVAL]. *)
  method write_node path (contents : 'contents) =
    if self#op_exists path WRITE || self#op_exists path RM || self#op_exists path NONE
    then (
      if not (super#node_exists path) then super#create_node path;
      self#op_add path WRITE;
      super#write_node path contents
    )
    else if self#store#node_exists path
    then (
      self#create_node path;
      super#write_node path contents
    )
    else raise (Constants.Xs_error (Constants.EINVAL, "Transaction.transaction_store#write_node", path))
end

(** Registry of the transactions currently open against [store].

    For every open transaction it keeps the transaction's private
    [transaction_store] and a list of per-domain counters
    ([changed_domain]).  It also tracks, per domain, the number of open
    transactions and the next transaction id to hand out.  All transaction
    stores share one [transaction_reads] registry, which is what lets a
    commit invalidate the reads of other transactions. *)
class ['contents] transactions (store : 'contents Store.store) =
object (self)
  val m_base_store = store
  val m_num_transactions = Hashtbl.create 8
  val m_reads = new transaction_reads
  val m_transaction_changed_domains = Transaction_hashtbl.create 8
  val m_transaction_ids = Hashtbl.create 8
  val m_transactions = Transaction_hashtbl.create 8

  (* Registers [transaction] with a fresh private view over [store], an
     initial zero counter for its own domain, and bumps the domain's count
     of open transactions.  Does nothing if already registered. *)
  method private add transaction store =
    if not (Transaction_hashtbl.mem self#transactions transaction)
    then (
      Transaction_hashtbl.add self#transactions transaction (new transaction_store transaction store self#reads);
      Transaction_hashtbl.add self#transaction_changed_domains transaction [ { id = transaction.domain_id; entries = 0 } ];
      (* [num_transactions_for_domain] yields 0 for an unknown domain, so no
         exception handler is needed here. *)
      Hashtbl.replace self#num_transactions transaction.domain_id (succ (self#num_transactions_for_domain transaction.domain_id))
    )
  method private num_transactions = m_num_transactions
  method private reads = m_reads

  (* Forgets [transaction]: its recorded reads, its private store and its
     per-domain counters.  The domain's open-transaction count is only
     decremented when the transaction was actually registered, so the
     counter cannot be driven negative by an unknown transaction. *)
  method private remove transaction =
    let was_registered = Transaction_hashtbl.mem self#transactions transaction in
    self#reads#remove_transaction_operations transaction;
    Transaction_hashtbl.remove self#transactions transaction;
    Transaction_hashtbl.remove self#transaction_changed_domains transaction;
    if was_registered
    then Hashtbl.replace self#num_transactions transaction.domain_id (pred (self#num_transactions_for_domain transaction.domain_id))
  method private transaction_changed_domains = m_transaction_changed_domains
  method private transaction_ids = m_transaction_ids
  method private transaction_store transaction = Transaction_hashtbl.find self#transactions transaction
  method private transactions = m_transactions

  (* A transaction is valid when none of its recorded reads has been marked
     modified.  A transaction without recorded reads is always valid; only
     that [Not_found] case is caught, other exceptions propagate. *)
  method private validate transaction =
    try
      not (List.exists
             (fun op -> equal op.transaction transaction && op.modified)
             (self#reads#transaction_operations transaction))
    with Not_found -> true

  (** The store the transactions are opened against. *)
  method base_store = m_base_store

  (** [commit transaction] validates the transaction.  On success it marks
      the paths it changed as modified for every transaction that read them,
      merges its private store into the base store, forgets the transaction
      and returns the list of changed elements.
      @raise Not_found if the transaction is unknown, or if validation fails
      (in which case the transaction is forgotten first). *)
  method commit transaction =
    if self#validate transaction
    then (
      let tstore = self#transaction_store transaction in
      let changed_nodes = tstore#changed_nodes in
      self#invalidate_nodes changed_nodes;
      tstore#merge;
      self#remove transaction;
      changed_nodes
    )
    else (
      self#remove transaction;
      raise Not_found
    )

  (** [domain_entries transaction] returns the per-domain counters of
      [transaction].
      @raise Not_found if the transaction is unknown. *)
  method domain_entries transaction = Transaction_hashtbl.find self#transaction_changed_domains transaction

  (* Adds [delta] to the counter of [domain_id] within [transaction]; a
     domain without a counter starts from 0.  The updated counter is moved
     to the head of the list.  Raises [Not_found] if the transaction is
     unknown. *)
  method private adjust_domain_entry (transaction : tr) domain_id delta =
    let current_entries = self#domain_entries transaction in
    let previous_count =
      try (List.find (fun entry -> entry.id = domain_id) current_entries).entries
      with Not_found -> 0 in
    let other_entries = List.filter (fun entry -> entry.id <> domain_id) current_entries in
    Transaction_hashtbl.replace self#transaction_changed_domains transaction
      ({ id = domain_id; entries = previous_count + delta } :: other_entries)

  (** [domain_entry_decr transaction domain_id] decrements the counter of
      [domain_id] within [transaction].
      @raise Not_found if the transaction is unknown. *)
  method domain_entry_decr (transaction : tr) domain_id =
    self#adjust_domain_entry transaction domain_id (- 1)

  (** [domain_entry_incr transaction domain_id] increments the counter of
      [domain_id] within [transaction].
      @raise Not_found if the transaction is unknown. *)
  method domain_entry_incr (transaction : tr) domain_id =
    self#adjust_domain_entry transaction domain_id 1

  (** [exists transaction] tells whether [transaction] is currently open. *)
  method exists transaction = Transaction_hashtbl.mem self#transactions transaction

  (** [invalidate path] marks every recorded read of [path] as modified.
      Paths nobody has read are ignored. *)
  method invalidate path = try List.iter (fun op -> op.modified <- true) (self#reads#path_operations path) with Not_found -> ()

  (** [invalidate_nodes nodes] invalidates the path of every element. *)
  method invalidate_nodes nodes = List.iter (fun node -> self#invalidate node.path) nodes

  (** [new_transaction domain store] opens a transaction for [domain] over
      [store] and returns its identity.  Ids are handed out per domain,
      starting at 1; id 0 and ids still in use are skipped. *)
  method new_transaction (domain : Domain.domain) store =
    let transaction_id =
      try Hashtbl.find self#transaction_ids domain#id with Not_found -> 1l in
    let transaction = make domain#id transaction_id in
    Hashtbl.replace self#transaction_ids domain#id (Int32.succ transaction_id);
    if not (Transaction_hashtbl.mem self#transactions transaction) && transaction.transaction_id <> 0l
    then (self#add transaction store; transaction)
    else self#new_transaction domain store

  (** Number of transactions currently open for [domain_id] (0 if none). *)
  method num_transactions_for_domain domain_id = try Hashtbl.find self#num_transactions domain_id with Not_found -> 0

  (** [remove_domain domain] forgets every transaction of [domain] together
      with the domain's transaction counter and id allocator. *)
  method remove_domain (domain : Domain.domain) =
    let domain_id = domain#id in
    (* Collect the keys first: [remove] deletes from [self#transactions],
       and modifying a hash table while iterating over it is unspecified. *)
    let doomed =
      Transaction_hashtbl.fold
        (fun transaction _ acc ->
          if transaction.domain_id = domain_id then transaction :: acc else acc)
        self#transactions [] in
    List.iter (fun transaction -> self#remove transaction) doomed;
    Hashtbl.remove self#num_transactions domain_id;
    Hashtbl.remove self#transaction_ids domain_id

  (** [store transaction] returns the private store of [transaction], or the
      base store when the transaction is unknown. *)
  method store transaction = try ((self#transaction_store transaction) :> 'contents Store.store) with Not_found -> self#base_store
end
