/*
 * This file and its contents are licensed under the Timescale License.
 * Please see the included NOTICE for copyright information and
 * LICENSE-TIMESCALE for a copy of the license.
 */

#include <postgres.h>
#include <utils/hsearch.h>
#include <utils/snapmgr.h>

#include "compat/compat.h"

#include "continuous_aggs/insert.h"
#include "debug_point.h"
#include "guc.h"
#include "invalidation.h"
#include "partitioning.h"

/*
 * When tuples in a hypertable that has a continuous aggregate are modified, the
 * lowest modified value and the greatest modified value must be tracked over
 * the course of a transaction or statement. At the end of the statement these
 * values will be inserted into the proper cache invalidation log table for
 * their associated hypertable if they are below the speculative materialization
 * watermark (or, if in REPEATABLE_READ isolation level or higher, they will be
 * inserted no matter what as we cannot see if a materialization transaction has
 * started and moved the watermark during our transaction in that case).
 *
 * We accomplish this at the transaction level by keeping a hash table of each
 * hypertable that has been modified in the transaction and the lowest and
 * greatest modified values. The hashtable will be updated via ModifyHypertable
 * for every row that is inserted, updated or deleted.
 * We use a hashtable because we need to keep track of this on a per hypertable
 * basis and multiple can have tuples modified during a single transaction.
 * (And if we move to per-chunk cache-invalidation it makes it even easier).
 *
 */
typedef struct ContinuousAggsCacheInvalEntry
{
	/*
	 * Relid of the modified chunk. continuous_agg_dml_invalidate() passes a
	 * chunk relid as the lookup key; dynahash reads the key from the start of
	 * the entry, so this must remain the first member.
	 */
	Oid chunk_relid;
	/* Hypertable id the invalidation log entry is written for. */
	int32 hypertable_id;
	/* Copy of the hypertable's first open dimension, taken at entry init. */
	Dimension hypertable_open_dimension;
	/* Attribute number of the open dimension's column, looked up on the chunk. */
	AttrNumber open_dimension_attno;
	/* True once a value has been recorded; unset entries are not written. */
	bool value_is_set;
	/* Smallest modified value seen so far (starts at INVAL_POS_INFINITY). */
	int64 lowest_modified_value;
	/* Largest modified value seen so far (starts at INVAL_NEG_INFINITY). */
	int64 greatest_modified_value;
} ContinuousAggsCacheInvalEntry;

/*
 * Returns the lowest invalidation threshold watermark stored for the given
 * hypertable; defined at the end of this file.
 */
static int64 get_lowest_invalidated_time_for_hypertable(int32 hypertable_id);

/* Initial element count passed to hash_create() for the invalidation table. */
#define CA_CACHE_INVAL_INIT_HTAB_SIZE 64

/*
 * Invalidation state. Both are created together by cache_inval_init() on
 * first use and reset to NULL by cache_inval_cleanup().
 */
static HTAB *continuous_aggs_cache_inval_htab = NULL;
static MemoryContext continuous_aggs_invalidation_mctx = NULL;

/* Forward declarations of the file-local helpers defined below. */
static int64 tuple_get_time(Dimension *d, HeapTuple tuple, AttrNumber col, TupleDesc tupdesc);
static inline void cache_inval_entry_init(ContinuousAggsCacheInvalEntry *cache_entry,
										  int32 hypertable_id, Oid chunk_relid);
static inline void update_cache_entry(ContinuousAggsCacheInvalEntry *cache_entry, int64 timeval);
static void cache_inval_entry_write(ContinuousAggsCacheInvalEntry *entry);
static void cache_inval_cleanup(void);
static void cache_inval_htab_write(void);
static void continuous_agg_xact_invalidation_callback(XactEvent event, void *arg);
static ScanTupleResult invalidation_tuple_found(TupleInfo *ti, void *min);

static void
cache_inval_init()
{
	HASHCTL ctl;

	Assert(continuous_aggs_invalidation_mctx == NULL);

	continuous_aggs_invalidation_mctx = AllocSetContextCreate(TopTransactionContext,
															  "ContinuousAggsInvalidationCtx",
															  ALLOCSET_DEFAULT_SIZES);

	memset(&ctl, 0, sizeof(ctl));
	ctl.keysize = sizeof(int32);
	ctl.entrysize = sizeof(ContinuousAggsCacheInvalEntry);
	ctl.hcxt = continuous_aggs_invalidation_mctx;

	continuous_aggs_cache_inval_htab = hash_create("TS Continuous Aggs Cache Inval",
												   CA_CACHE_INVAL_INIT_HTAB_SIZE,
												   &ctl,
												   HASH_ELEM | HASH_BLOBS | HASH_CONTEXT);
};

/*
 * Read attribute `col` of `tuple` (described by `tupdesc`) and convert it to
 * the internal int64 time representation for the partition type of open
 * dimension `d`.
 *
 * Raises an error if the attribute is NULL.
 */
static int64
tuple_get_time(Dimension *d, HeapTuple tuple, AttrNumber col, TupleDesc tupdesc)
{
	bool value_is_null;
	Datum value = heap_getattr(tuple, col, tupdesc, &value_is_null);

	/*
	 * This is the primary partitioning column, and partitioning columns are
	 * required to be NOT NULL, so a NULL here is never expected.
	 */
	Ensure(!value_is_null, "primary partition column cannot be NULL");

	Assert(d->type == DIMENSION_TYPE_OPEN);

	return ts_time_value_to_internal(value, ts_dimension_get_partition_type(d));
}

/*
 * Fill in a freshly created hash entry for `chunk_relid`.
 *
 * Looks the hypertable up by id in the hypertable cache, copies its first open
 * dimension into the entry, resolves that dimension's column to an attribute
 * number on the chunk, and resets the tracked range to "empty".
 *
 * Raises an error if the hypertable or its open dimension cannot be found.
 */
static inline void
cache_inval_entry_init(ContinuousAggsCacheInvalEntry *cache_entry, int32 hypertable_id,
					   Oid chunk_relid)
{
	Cache *hcache = ts_hypertable_cache_pin();
	Hypertable *hypertable = ts_hypertable_cache_get_entry_by_id(hcache, hypertable_id);

	Ensure(hypertable, "could not find hypertable with id %d", hypertable_id);

	const Dimension *time_dim = hyperspace_get_open_dimension(hypertable->space, 0);

	Ensure(time_dim, "hypertable %d has no open partitioning dimension", hypertable_id);

	/* Identity of the entry. */
	cache_entry->chunk_relid = chunk_relid;
	cache_entry->hypertable_id = hypertable_id;

	/* Everything needed later to extract the time value from chunk tuples. */
	cache_entry->hypertable_open_dimension = *time_dim;
	cache_entry->open_dimension_attno = get_attnum(chunk_relid, NameStr(time_dim->fd.column_name));

	/*
	 * Empty range: lowest starts at +infinity and greatest at -infinity so
	 * that the first recorded value replaces both bounds.
	 */
	cache_entry->lowest_modified_value = INVAL_POS_INFINITY;
	cache_entry->greatest_modified_value = INVAL_NEG_INFINITY;
	cache_entry->value_is_set = false;

	ts_cache_release(&hcache);
}

/*
 * Widen the entry's [lowest, greatest] range so that it includes `timeval`
 * and mark the entry as holding a value.
 *
 * The two bounds are adjusted independently: on an empty entry a single value
 * has to replace both of them.
 */
static inline void
update_cache_entry(ContinuousAggsCacheInvalEntry *cache_entry, int64 timeval)
{
	const int64 prev_lowest = cache_entry->lowest_modified_value;
	const int64 prev_greatest = cache_entry->greatest_modified_value;

	cache_entry->lowest_modified_value = (timeval < prev_lowest) ? timeval : prev_lowest;
	cache_entry->greatest_modified_value = (timeval > prev_greatest) ? timeval : prev_greatest;
	cache_entry->value_is_set = true;
}

/*
 * Record the whole range [start, end] as modified for a chunk, without looking
 * at individual tuples. Used by direct compress invalidation.
 *
 * hypertable_id: id of the hypertable the invalidation is logged for
 * chunk_relid:   relid of the chunk the range belongs to
 * start, end:    range bounds, start <= end; they are merged into the same
 *                fields as the values produced by tuple_get_time()
 *
 * The invalidation state is created on first use.
 */
void
continuous_agg_invalidate_range(int32 hypertable_id, Oid chunk_relid, int64 start, int64 end)
{
	ContinuousAggsCacheInvalEntry *cache_entry;
	bool found;

	Assert(start <= end);

	if (!continuous_aggs_cache_inval_htab)
		cache_inval_init();

	/*
	 * The table is keyed by chunk relid, the leading member of the entry and
	 * the key continuous_agg_dml_invalidate() uses. Looking up by hypertable
	 * id here would insert the entry under one key and then have
	 * cache_inval_entry_init() overwrite the stored key with chunk_relid,
	 * leaving an entry that can never be found again.
	 */
	cache_entry = (ContinuousAggsCacheInvalEntry *)
		hash_search(continuous_aggs_cache_inval_htab, &chunk_relid, HASH_ENTER, &found);

	if (!found)
		cache_inval_entry_init(cache_entry, hypertable_id, chunk_relid);

	cache_entry->value_is_set = true;
	if (start < cache_entry->lowest_modified_value)
		cache_entry->lowest_modified_value = start;
	if (end > cache_entry->greatest_modified_value)
		cache_entry->greatest_modified_value = end;
}

/*
 * Record the time value(s) touched by a single row modification on a chunk.
 *
 * hypertable_id:  id of the hypertable the invalidation is logged for
 * chunk_rel:      relation of the chunk holding the tuple
 * chunk_tuple:    the tuple being modified (for updates, the old version)
 * chunk_newtuple: the new tuple version; only read when `update` is true
 * update:         true if the modification is an update
 *
 * Must not be used when WAL-based invalidation is enabled (asserted).
 */
void
continuous_agg_dml_invalidate(int32 hypertable_id, Relation chunk_rel, HeapTuple chunk_tuple,
							  HeapTuple chunk_newtuple, bool update)
{
	Assert(!ts_guc_enable_cagg_wal_based_invalidation);
	Oid relid = chunk_rel->rd_id;
	bool entry_exists;
	ContinuousAggsCacheInvalEntry *entry;

	/* The memory context and hash table are created on first use. */
	if (continuous_aggs_cache_inval_htab == NULL)
		cache_inval_init();

	entry = (ContinuousAggsCacheInvalEntry *)
		hash_search(continuous_aggs_cache_inval_htab, &relid, HASH_ENTER, &entry_exists);

	if (!entry_exists)
		cache_inval_entry_init(entry, hypertable_id, relid);

	update_cache_entry(entry,
					   tuple_get_time(&entry->hypertable_open_dimension,
									  chunk_tuple,
									  entry->open_dimension_attno,
									  RelationGetDescr(chunk_rel)));

	/* An update invalidates the new time value in addition to the old one. */
	if (update)
		update_cache_entry(entry,
						   tuple_get_time(&entry->hypertable_open_dimension,
										  chunk_newtuple,
										  entry->open_dimension_attno,
										  RelationGetDescr(chunk_rel)));
}

/*
 * Append the range tracked by `entry` to the hypertable invalidation log, if
 * it needs to be logged. Entries that never received a value are skipped.
 */
static void
cache_inval_entry_write(ContinuousAggsCacheInvalEntry *entry)
{
	if (!entry->value_is_set)
		return;

	/*
	 * The materialization worker runs at READ COMMITTED by default. Under a
	 * stronger isolation level the threshold may be moved without this
	 * transaction being able to see the new value, so in that case the entry
	 * is always logged and the threshold is not read at all; this avoids
	 * serialization errors, and the materializer copes gracefully with
	 * invalidations that lie beyond the threshold.
	 *
	 * Otherwise the range is logged only when it starts below the current
	 * threshold. The short-circuit keeps the threshold lookup from running
	 * when the isolation level already decides the outcome.
	 */
	if (IsolationUsesXactSnapshot() ||
		entry->lowest_modified_value <
			get_lowest_invalidated_time_for_hypertable(entry->hypertable_id))
	{
		invalidation_hyper_log_add_entry(entry->hypertable_id,
										 entry->lowest_modified_value,
										 entry->greatest_modified_value);
	}
}

/*
 * Tear down the invalidation state: destroy the hash table, delete the memory
 * context created for it, and reset both file-level pointers so the next
 * modification starts from scratch.
 */
static void
cache_inval_cleanup(void)
{
	Assert(continuous_aggs_cache_inval_htab != NULL);

	/* Table first, then the context it was allocated in. */
	hash_destroy(continuous_aggs_cache_inval_htab);
	continuous_aggs_cache_inval_htab = NULL;

	MemoryContextDelete(continuous_aggs_invalidation_mctx);
	continuous_aggs_invalidation_mctx = NULL;
}

/*
 * Write every entry of the invalidation hash table via
 * cache_inval_entry_write(). Does nothing when the table is empty.
 */
static void
cache_inval_htab_write(void)
{
	HASH_SEQ_STATUS scan;
	ContinuousAggsCacheInvalEntry *entry;

	if (hash_get_num_entries(continuous_aggs_cache_inval_htab) == 0)
		return;

	/*
	 * Lock the invalidation threshold table explicitly: the lock has to be
	 * held until the end of the transaction so that the materializer is
	 * guaranteed to see our updates.
	 */
	LockRelationOid(catalog_get_table_id(ts_catalog_get(), CONTINUOUS_AGGS_INVALIDATION_THRESHOLD),
					AccessShareLock);

	hash_seq_init(&scan, continuous_aggs_cache_inval_htab);
	for (entry = hash_seq_search(&scan); entry != NULL; entry = hash_seq_search(&scan))
		cache_inval_entry_write(entry);
}

/*
 * We use TopTransactionContext for our cached invalidations.
 * We need to make sure cache_inval_cleanup() is always called after cache_inval_htab_write().
 * We need this memory context to survive the transaction lifetime so that cache_inval_cleanup()
 * does not attempt to tear down memory that has already been freed due to a transaction ending.
 *
 * The order of operations in postgres can be this:
 * CallXactCallbacks(XACT_EVENT_PRE_PREPARE);
 * ...
 * CallXactCallbacks(XACT_EVENT_PREPARE);
 * ...
 * MemoryContextDelete(TopTransactionContext);
 *
 * or that:
 * CallXactCallbacks(XACT_EVENT_PRE_COMMIT);
 * ...
 * CallXactCallbacks(XACT_EVENT_COMMIT);
 * ...
 * MemoryContextDelete(TopTransactionContext);
 *
 * In the case of a 2PC transaction, we need to make sure to apply the invalidations at
 * XACT_EVENT_PRE_PREPARE time, before TopTransactionContext is torn down by PREPARE TRANSACTION.
 * Otherwise, we are unable to call cache_inval_cleanup() without corrupting the memory. For
 * this reason, we also deallocate at XACT_EVENT_PREPARE time.
 *
 * For local transactions we apply the invalidations at XACT_EVENT_PRE_COMMIT time.
 * Similar care is taken of parallel workers and aborting transactions.
 */
static void
continuous_agg_xact_invalidation_callback(XactEvent event, void *arg)
{
	/* Nothing was modified in this transaction: no state to flush or free. */
	if (continuous_aggs_cache_inval_htab == NULL)
		return;

	switch (event)
	{
		/* "Pre" events: the accumulated ranges are written out. */
		case XACT_EVENT_PRE_PREPARE:
		case XACT_EVENT_PRE_COMMIT:
		case XACT_EVENT_PARALLEL_PRE_COMMIT:
			cache_inval_htab_write();
			return;

		/* End of the transaction, successful or not: release the state. */
		case XACT_EVENT_PREPARE:
		case XACT_EVENT_COMMIT:
		case XACT_EVENT_PARALLEL_COMMIT:
		case XACT_EVENT_ABORT:
		case XACT_EVENT_PARALLEL_ABORT:
			cache_inval_cleanup();
			return;

		default:
			return;
	}
}

/*
 * Register the transaction callback that writes out and releases the
 * invalidation state kept by this file.
 * NOTE(review): presumably called once when the module is loaded; confirm
 * against the caller.
 */
void
_continuous_aggs_cache_inval_init(void)
{
	RegisterXactCallback(continuous_agg_xact_invalidation_callback, NULL);
}

/*
 * Unregister the transaction callback installed by
 * _continuous_aggs_cache_inval_init().
 */
void
_continuous_aggs_cache_inval_fini(void)
{
	UnregisterXactCallback(continuous_agg_xact_invalidation_callback, NULL);
}

/*
 * Scanner callback for the invalidation threshold table.
 *
 * `min` points to an int64 holding the lowest watermark seen so far; it is
 * lowered when the scanned tuple's watermark is smaller.
 */
static ScanTupleResult
invalidation_tuple_found(TupleInfo *ti, void *min)
{
	int64 *lowest = (int64 *) min;
	bool watermark_is_null;
	Datum watermark_datum = slot_getattr(ti->slot,
										 Anum_continuous_aggs_invalidation_threshold_watermark,
										 &watermark_is_null);

	Assert(!watermark_is_null);

	int64 watermark = DatumGetInt64(watermark_datum);

	if (watermark < *lowest)
		*lowest = watermark;

	DEBUG_WAITPOINT("invalidation_tuple_found_done");

	/*
	 * Keep scanning rather than stopping at the first match: more than one
	 * matching tuple is checked for as an error condition.
	 */
	return SCAN_CONTINUE;
}

/*
 * Return the invalidation threshold watermark stored for `hypertable_id`, or
 * INVAL_NEG_INFINITY when no threshold row is found.
 *
 * The lookup runs under a freshly taken latest snapshot, which is pushed as
 * the active snapshot for the duration of the scan and popped afterwards.
 */
static int64
get_lowest_invalidated_time_for_hypertable(int32 hypertable_id)
{
	Catalog *catalog = ts_catalog_get();
	int64 lowest = INVAL_POS_INFINITY;
	ScanKeyData scankey[1];

	ScanKeyInit(&scankey[0],
				Anum_continuous_aggs_invalidation_threshold_pkey_hypertable_id,
				BTEqualStrategyNumber,
				F_INT4EQ,
				Int32GetDatum(hypertable_id));

	/* Must be pushed before the scanner context captures GetActiveSnapshot(). */
	PushActiveSnapshot(GetLatestSnapshot());

	ScannerCtx scanctx = {
		.table = catalog_get_table_id(catalog, CONTINUOUS_AGGS_INVALIDATION_THRESHOLD),
		.index = catalog_get_index(catalog,
								   CONTINUOUS_AGGS_INVALIDATION_THRESHOLD,
								   CONTINUOUS_AGGS_INVALIDATION_THRESHOLD_PKEY),
		.nkeys = 1,
		.scankey = scankey,
		.tuple_found = &invalidation_tuple_found,
		.data = &lowest,
		.lockmode = AccessShareLock,
		.scandirection = ForwardScanDirection,

		/*
		 * A custom snapshot is required for this scan. The default snapshot
		 * (SNAPSHOT_SELF) reads data of all committed transactions, even ones
		 * that started after our scan. If a parallel session updates the
		 * scanned value and commits while the scan is running, both the old
		 * and the new value become visible, which makes
		 * ts_scanner_scan_one() fail.
		 */
		.snapshot = GetActiveSnapshot(),
	};

	/*
	 * If no invalidation threshold watermark is found, no materialization has
	 * ever been done. Treat this as if the invalidation timestamp were at the
	 * minimum value: the first materialization has to scan the entire table
	 * anyway, so the invalidations would be redundant.
	 */
	if (!ts_scanner_scan_one(&scanctx, false, CAGG_INVALIDATION_THRESHOLD_NAME))
		lowest = INVAL_NEG_INFINITY;

	PopActiveSnapshot();

	return lowest;
}
