## --- Estimate DFM --- 
## First of two-step process to estimate DFM (PCA then feed factors into VAR)
## Author: P Hendy

# Estimate factors with PCA and select the number of factors with the
# Bai and Ng information criteria.
#
# Arguments:
#   x             Numeric matrix with one row per series and one column per
#                 observation. It is used exactly as supplied.
#   max.q         Largest number of factors evaluated by the information
#                 criterion. Defaults to round(sqrt(min(nrow(x), ncol(x)))).
#   q             Value reported as q.hat when bn = FALSE.
#   bn            If TRUE, evaluate the information criterion for 0..max.q
#                 factors and report the minimiser as q.hat.
#   bn.op         Integer 1-5 selecting the penalty term of the criterion.
#   normalisation Kept for backward compatibility. The normalisation code was
#                 commented out in the original, so this flag has no effect and
#                 callers are expected to pass already-standardised data.
#   named_factor  If TRUE, the top-left 4x4 block of the loadings is replaced by
#                 the identity and the remaining loadings of the first four
#                 series are set to zero before the factors are computed.
#
# Returns a list with the loadings (lam), factors (f, one row per factor), the
# input data (norm.x), the selected number of factors (q.hat), max.q, the
# information criterion values (ic) and, per factor and series, the share of
# each series' variance attributed to each factor (correlsq).
get.factor.model <- function(x, max.q=NULL, q=NULL, bn=TRUE, bn.op=2, normalisation=TRUE, named_factor=FALSE){
  # Do not call this `T`: the original did, which silently turned the later
  # `named_factor==T` test into a comparison with the number of columns.
  n_obs <- ncol(x)
  n <- nrow(x)
  cnt <- min(n, n_obs)
  if (is.null(max.q)) max.q <- round(sqrt(cnt))

  keep <- seq_len(cnt)
  xx <- x %*% t(x) / (n_obs - 1)
  eig <- eigen(xx, symmetric = TRUE)
  eig_vectors <- eig$vectors[, keep, drop = FALSE]
  eig_values <- eig$values[keep]

  if (isTRUE(named_factor)) {
    if (cnt < 4) {
      stop("named_factor = TRUE needs at least 4 series and 4 observations", call. = FALSE)
    }
    lam <- eig_vectors
    lam[1:4, 1:4] <- diag(4)
    # Guard the descending sequence 5:4 that the original produced when cnt == 4.
    if (cnt > 4) lam[1:4, 5:cnt] <- 0
    f <- t(lam) %*% x
    correlsq <- sweep(t(lam^2), 1, eig_values, FUN = "*")
  } else {
    lam <- eig_vectors * sqrt(n)
    f <- t(eig_vectors) %*% x / sqrt(n)
    correlsq <- sweep(t(lam^2 / n), 1, eig_values, FUN = "*")
  }
  # Express each contribution relative to the sample variance of the series.
  correlsq <- sweep(correlsq, 2, apply(x, 1, sd)^2, FUN = "/")

  if (bn) {
    ic <- rep(0, 1 + max.q)
    # Entry 1 is the criterion with zero factors.
    ic[1] <- (bn.op <= 4) * log(mean(x^2)) + (bn.op == 5) * mean(x^2)
    for (l in seq_len(max.q)) {
      # Common component implied by the first l factors.
      hchi <- lam[, 1:l, drop = FALSE] %*% f[1:l, , drop = FALSE]
      resid_ms <- mean((x - hchi)^2)
      ic[l + 1] <- (bn.op <= 4) * log(resid_ms) +
        (bn.op == 1) * l * (n + n_obs) / (n * n_obs) * log(n * n_obs / (n + n_obs)) +
        (bn.op == 2) * l * (n + n_obs) / (n * n_obs) * log(cnt) +
        (bn.op == 3) * l * log(cnt) / cnt +
        (bn.op == 4) * l * ((n + n_obs - l) * log(n * n_obs) / (n * n_obs) + (n + n_obs) / (n * n_obs) * log(cnt)) / 2 +
        (bn.op == 5) * (resid_ms + l * resid_ms * (n + n_obs - l) * log(n * n_obs) / (n * n_obs))
    }
    # which.min() always yields a single value (the first minimiser on ties).
    q.hat <- which.min(ic) - 1
  } else {
    ic <- rep(0, max.q)
    q.hat <- q
  }

  return(list(lam = lam, f = f, norm.x = x, q.hat = q.hat, max.q = max.q, ic = ic, correlsq = correlsq))
}

# Robustness for post-93
#df_global_qtr_stationary_nomiss <- df_global_qtr_stationary_nomiss %>% filter(Date>="1993-01-01")

# Estimation sample: observations up to the end of 2019. The same filtered
# frame supplies both the data matrix and the dates, so the two always line up
# (the original dropped a hard-coded number of trailing dates instead).
factor_sample <- df_global_qtr_stationary_nomiss %>% filter(Date<="2019-12-31")

# Transform data to matrix and transpose (one row per series)
ts_factor_data <- t(factor_sample %>% dplyr::select(-Date) %>% scale())

# Get factors then merge back in dates and graph ts
factors <- get.factor.model(ts_factor_data, bn=TRUE, normalisation=TRUE)
n_selected_factors <- as.numeric(factors$q.hat)
# The factor labels below are hard-coded for exactly three factors.
if (length(n_selected_factors) != 1 || n_selected_factors != 3) {
  stop("Expected the information criterion to select exactly 3 factors", call. = FALSE)
}
factors_ts <- as.data.frame(factors$f) %>% head(n_selected_factors)
factors_ts <- as.data.frame(t(factors_ts))
names(factors_ts) <- c("activity_factor", "prices_factor", "financial_factor")#, "interest_factor")
factors_ts$Date <- as.Date(factor_sample$Date)

# Flip the sign of the third factor
factors_ts$financial_factor <- -factors_ts$financial_factor

# Plot factor loadings: percentage of each series' variance attributed to each
# of the first three factors. cbind() with the series names makes every column
# character; values are converted back to numeric after reshaping.
loadings <- 100 * t(factors$correlsq[1:3, , drop = FALSE])
loadings <- as.data.frame(cbind(rownames(ts_factor_data), loadings))
# Name the columns directly rather than renaming auto-generated (and duplicated)
# "V1" column names in two steps.
names(loadings) <- c("variables", "factor1", "factor2", "factor3")
long_loadings <- pivot_longer(loadings, !variables, names_to="factorno", values_to="values")
long_loadings$values <- as.numeric(long_loadings$values)
long_country_loadings <- long_loadings

# Replicate the same procedure with the R prcomp package as a cross-check
rescaled_data <- scale(df_global_qtr_stationary_nomiss %>% filter(Date <= "2019-12-31") %>% dplyr::select(-Date))
pcas <- prcomp(rescaled_data)
lam <- pcas$rotation
check_factors <- pcas$x

# Share of total variance explained by each principal component, and the
# running total. The component labels are derived from the number of
# components prcomp() returned instead of a hard-coded count.
var_explained_df <- data.frame(
  PC = paste0("PC", seq_along(pcas$sdev)),
  var_explained = (pcas$sdev)^2 / sum((pcas$sdev)^2),
  prop_explained = cumsum(pcas$sdev^2 / sum(pcas$sdev^2))
)

# Bars: individual shares of the first nine components; line: cumulative share
var_explained_df %>% head(9) %>%
  ggplot()+
  geom_bar(aes(x=as.factor(PC),y=100*var_explained), stat='identity', fill="lightblue", colour="cadetblue") +
  geom_line(aes(x=as.factor(PC),y=100*prop_explained), stat='identity', group=1, colour="firebrick") +
  labs(title="Proportion of variation explained by factors") +
  ylab("Proportion of variation explained") +
  xlab("Factors") +
  theme_bw() +
  theme(plot.title = element_text(hjust = 0.5, face = "bold")) +
  scale_y_continuous(sec.axis = sec_axis(~. * 1, name = "Cumulative variation explained"))

# Graph 2
write.csv(var_explained_df, file=paste0(data_output_folder, '/Graph 2.csv'))

## Plot a heat map of variable loadings across factors/variables/countries
# Reshape the loadings to long form, tag every series with a variable category
# and a country group, then average the explained variation within each
# factor/category/country cell.
heatmap_loadings <- loadings %>%
  pivot_longer(!variables, names_to="factorno", values_to="values") %>%
  mutate(values = as.numeric(values)) %>%
  mutate(category = case_when(
    variables %in% activity_variables ~ "Activity",
    variables %in% prices_variables ~ "Prices",
    variables %in% interest_rate_variables ~ "Interest",
    variables %in% financial_variables ~ "Financial",
    variables %in% commodity_variables ~ "commodity_variables",
    TRUE ~ NA_character_
  )) %>%
  mutate(country = case_when(
    variables %in% us_variables ~ "US",
    variables %in% china_variables ~ "China",
    variables %in% other_ae_variables ~ "Advanced",
    variables %in% other_em_variables ~ "Emerging",
    TRUE ~ "Global"
  )) %>%
  group_by(factorno, category, country) %>%
  summarise(mean_explained = mean(values))# %>%
#  filter(!(category %in% c("commodity_variables", NA)))

# Fix the display order of the country groups and use readable factor labels
heatmap_loadings$country <- factor(
  heatmap_loadings$country,
  levels=c("China", "Emerging", "Advanced", "US", "Global")
)
heatmap_loadings$factorno <- ifelse(
  heatmap_loadings$factorno=="factor1",
  "Factor 1",
  ifelse(heatmap_loadings$factorno=="factor2", "Factor 2", "Factor 3")
)

# Plotting copy: one blank-country placeholder row per factor is appended and
# the "Global" group is dropped.
heatmap_loadings_test <- heatmap_loadings %>%
  ungroup() %>%
  add_row(
    factorno=c("Factor 1", "Factor 2", "Factor 3"),
    category="Activity",
    country=" ",
    mean_explained=NA
  ) %>%
  filter(!is.na(country), country != "Global")
heatmap_loadings_test$country <- factor(
  heatmap_loadings_test$country,
  levels = c("China", "Emerging", "Advanced", "US", " "),
  ordered = TRUE
)

# Build the loading heat map for one factor.
#   factor: factor label as stored in heatmap_loadings_test$factorno
#           (e.g. "Factor 1"); also used as the plot title.
# Returns the plot object.
create_heatmap <- function(factor) {
  # Cells that are actually drawn: real country groups, commodities excluded
  tile_data <- heatmap_loadings_test %>%
    filter(!country %in% c("Global", " "), category != "commodity_variables", factorno == factor)
  # Placeholder rows for the blank country level
  spacer_data <- heatmap_loadings_test %>%
    filter(country == " ", factorno == factor)

  base_plot <- ggrba(tile_data, aes(x = category, y = country, fill = mean_explained)) +
    geom_tile(color = "white") +
    geom_tile(data = spacer_data,
              aes(x = category, y = country, height=0.01), fill = NA, color = NA)  # Transparent dummy row

  base_plot +
    scale_fill_gradientn(colors = viridis_pal()(9), limits = c(0, 42), name="Variation explained (%)") +
    labs(title = factor) +
    theme(legend.title = element_text(size = 18),
          legend.text = element_text(size = 18),
          axis.text.x = element_text(angle=45, hjust=0.8)) +
    scale_y_discrete(limits = c("China", "Emerging", "Advanced", "US", " "))
}

# One heat map per factor
heatmap1 <- create_heatmap("Factor 1")
heatmap2 <- create_heatmap("Factor 2")
heatmap3 <- create_heatmap("Factor 3")

# Combine the three panels with a shared legend; the theme after `&` is applied
# to the combined object.
output_heatmap <- (
  heatmap1 + heatmap2 + heatmap3 +
    source_rba("Author's calculations") +
    plot_layout(guides="collect")
) &
  theme(
    legend.position='bottom',
    legend.direction='horizontal',
    legend.key.size = unit(1, 'cm')
  )#+ plot_size_rba(plot.width = grid::unit(380, "mm")) 
# Graph 3
ggsave_rba("output_heatmap.svg", output_heatmap)

# Table B1
heatmap_loadings %>%
  pivot_wider(names_from="country", values_from="mean_explained") %>%
  arrange(factorno, category) %>%
  dplyr::select(factorno, category, US, Advanced, Emerging, China, Global) %>%
  write.csv(paste0(data_output_folder, "/table1.csv"))

# Generate factor plot
# Four-quarter moving average of each factor: the current value and the three
# preceding quarters, equally weighted. The original summed lags 0, 2, 3 and 4,
# skipping lag 1 and reaching back a fifth quarter.
# Rows are assumed to be in date order within each factor (factors_ts is built
# in date order).
graph_factors_ts <- pivot_longer(factors_ts, !Date, names_to="colfactor", values_to="value") %>%
  group_by(colfactor) %>%
  mutate(fourq_factor = 0.25*(value + dplyr::lag(value, 1) + dplyr::lag(value, 2) + dplyr::lag(value, 3)))

# Graph 4
write.csv(graph_factors_ts %>% dplyr::select(-value) %>% filter(!is.na(fourq_factor)) %>% pivot_wider(names_from="colfactor", values_from="fourq_factor"),
          file=paste0(data_output_folder, "/Graph 4.csv"))


# Function to return average explained domestic variation by an additional factor
#
#   input_factor: name of a column of check_factors (e.g. "PC7").
#
# Fits a VAR on the first three principal components, plus input_factor when it
# is not one of those three, together with the columns of df_aus_stationary.
# Zero restrictions stop the Australian variables from entering the factor
# equations. Returns the share of forecast error variance attributed to the
# factors at the 16-step horizon, averaged over twi, cpi, gdp, cash_rate and
# unemp.
#
# Reads the globals check_factors, factors_ts and df_aus_stationary.
pick_aus_expl <- function(input_factor) {
  base_factors <- c("PC1", "PC2", "PC3")
  aus_targets <- c("twi", "cpi", "gdp", "cash_rate", "unemp")

  # Baseline factors, plus the candidate when it is not already one of them
  factor_cols <- if (input_factor %in% base_factors) {
    base_factors
  } else {
    c(base_factors, input_factor)
  }
  n_factors <- length(factor_cols)

  pick_factors_ts <- as.data.frame(check_factors) %>%
    dplyr::select(dplyr::all_of(factor_cols)) %>%
    cbind(Date=floor_date(factors_ts$Date, unit = "month"))
  # Factors first, then the Australian variables; Date is moved to column 1 and
  # dropped when the VAR is fitted.
  pick_var_data <- left_join(pick_factors_ts, df_aus_stationary, by="Date") %>%
    dplyr::select(Date, everything()) %>% filter(Date <= "2019-12-31")

  # Run initial VAR
  pick_favar <- vars::VAR(pick_var_data[,-1], ic="FPE", type="both")

  # Impose zero restrictions on effect of Australian variables on international factors.
  # Rows are equations; columns are the lagged variables (lag by lag) followed
  # by the constant and the trend. The matrix is sized from the fitted model,
  # so zeros fall only on the factor equations. The original hard-coded
  # three-factor matrix also zeroed the fourth equation, i.e. the first
  # Australian variable, unlike the four-factor case.
  n_vars <- pick_favar$K
  n_lags <- pick_favar$p
  pick_restriction_mat <- matrix(1, nrow = n_vars, ncol = n_vars * n_lags + 2)
  aus_regressors <- as.vector(
    outer((n_factors + 1):n_vars, (seq_len(n_lags) - 1) * n_vars, "+")
  )
  pick_restriction_mat[seq_len(n_factors), aus_regressors] <- 0
  pick_restricted_favar <- vars::restrict(pick_favar, method="manual", resmat=pick_restriction_mat)

  # Get FEVDs
  pick_restricted_fevd <- vars::fevd(pick_restricted_favar, n.ahead=16)

  # Share attributed to the factors at the last horizon, per Australian variable
  foreign_share <- vapply(aus_targets, function(target) {
    decomposition <- pick_restricted_fevd[[target]]
    sum(decomposition[nrow(decomposition), seq_len(n_factors)])
  }, numeric(1))

  # Average across Australian variables
  mean(foreign_share)
}

# Generate 'scree plot' of r^2s for local projections of Aus variables on each individual factor
aus_scree <- check_factors %>% cbind(df_aus_stationary %>% filter(Date<="2019-12-31") %>% dplyr::select(-Date)) #%>% head(-2))
# One value per principal component, in column order of check_factors
best_aus_factors <- map(colnames(check_factors), pick_aus_expl)
# Index the components by position instead of assuming there are exactly 119,
# and name the column directly rather than renaming a deparsed expression.
best_aus_factors_plot <- data.frame(
  fev = unlist(best_aus_factors),
  pc = seq_along(best_aus_factors)
)

ggplot(best_aus_factors_plot, aes(x=pc, y=fev)) + geom_bar(stat='identity', fill='lightblue', colour='cadetblue')

# Graph 8
write.csv(best_aus_factors_plot, paste0(data_output_folder, "/Graph 8.csv"))

# Pick highest
# Keep the component indices (pc) of the five largest fev values.
picked_factor <- best_aus_factors_plot %>% arrange(desc(fev)) %>% head(5)%>% dplyr::select(pc)

# Graph category loading of highest explained
# Re-estimates the factor model with the same arguments used for `factors`.
pick_factors <- get.factor.model(ts_factor_data, bn=TRUE, normalisation=TRUE)
# Keep factor rows up to the highest picked index, then transpose so that each
# factor is a column.
pick_factors_ts <- as.data.frame(pick_factors$f) %>% head(max(picked_factor))
pick_factors_ts <- t(pick_factors_ts)
# First three factors followed by the picked ones.
# NOTE(review): if a picked index is 1, 2 or 3 the column appears twice — confirm
# this is intended.
pick_factors_ts <- pick_factors_ts[,c(1,2,3,unlist(picked_factor))] 
pick_factors_ts <- as.data.frame(pick_factors_ts) %>% cbind(Date=df_global_qtr_stationary_nomiss %>% filter(Date<"2020-01-01") %>% dplyr::select(Date)) %>% rename("activity_factor" = "1", "prices_factor" = "2", "financial_factor" = "3")

# Explained-variation shares (in percent) of every series for the picked factors
pick_loadings <- as.data.frame(t(pick_factors$correlsq[unlist(picked_factor),]))
pick_loadings <- apply(pick_loadings, 2, function(x) 100*x)
# cbind() with the series names turns every column into character; the values
# are converted back to numeric after reshaping below.
pick_loadings <- as.data.frame(cbind(rownames(ts_factor_data), pick_loadings)) %>%
  rename("variables" = "V1")# %>%
  #rename("new_factor" = "pick_loadings")
pick_long_loadings <- pivot_longer(pick_loadings, !variables, names_to="factorno", values_to="values")
pick_long_loadings$values <- as.numeric(pick_long_loadings$values)
# Copy kept for the country breakdown further down
pick_long_country_loadings <- pick_long_loadings
# Average loading per picked factor and variable category; series matching no
# category list get NA and are removed on the next line.
pick_long_loadings <- pick_long_loadings %>% group_by(factorno) %>% mutate(isfactor = ifelse(variables %in% activity_variables, "activity_variables", ifelse(variables %in% prices_variables, "prices_variables", ifelse(variables %in% interest_rate_variables, "interest_rate_variables", ifelse(variables %in% financial_variables, "financial_variables", ifelse(variables %in% commodity_variables, "commodity_variables", NA)))))) %>% group_by(factorno, isfactor) %>% summarise(ave_loading = mean(values))
pick_long_loadings <- pick_long_loadings %>% filter(!is.na(isfactor))
pick_long_loadings$colour <- pick_long_loadings$factorno

# Graph country loading of highest explained
# Average loading per picked factor and country group. Every series that is not
# in the US, China or other-AE lists is labelled "other_em_variables", so
# isfactor is never NA here and the filter on the next line removes nothing.
# NOTE(review): the heat map code earlier in this file puts series outside all
# four country lists in a separate "Global" group — confirm whether that
# grouping is wanted here too.
pick_long_country_loadings <- pick_long_country_loadings %>% group_by(factorno) %>% mutate(isfactor = ifelse(variables %in% us_variables, "us_variables", ifelse(variables %in% china_variables, "china_variables", ifelse(variables %in% other_ae_variables, "other_ae_variables", "other_em_variables")))) %>% group_by(factorno, isfactor) %>% summarise(ave_loading = mean(values))
pick_long_country_loadings <- pick_long_country_loadings %>% filter(!is.na(isfactor))
pick_long_country_loadings$colour <- pick_long_country_loadings$factorno
# Average across the picked factors within each country group
picked_country_loadings <- pick_long_country_loadings %>% group_by(isfactor) %>% summarise(ave_loading = mean(ave_loading))

# Horizontal bar chart of the per-factor country-group averages
ggplot(pick_long_country_loadings, aes(x=isfactor, y=as.numeric(ave_loading))) +
  geom_col() + 
  coord_flip() + 
  guides(fill="none") + theme_bw() +
  theme(text = element_text(size = 20)) 

# Figure 9
# NOTE(review): the comment says Figure 9 but the file is "Graph 11.csv" — confirm
# which label is current.
write.csv(picked_country_loadings, paste0(data_output_folder, "/Graph 11.csv"))

# China experiments
# Drop the Chinese series (and the Asian corporate spread) before re-estimating
# the factors. all_of() makes explicit that china_variables is an external
# character vector of column names.
no_china <- df_global_qtr_stationary_nomiss %>% dplyr::select(-c(dplyr::all_of(china_variables), "Asian.corporate.spread.to.benchmark"))
# Estimation sample up to the end of 2019. The dates attached to the factors
# below come from this same filtered frame. The original dropped a hard-coded
# 13 trailing dates here but 11 for the full data set, although both frames
# have the same rows, so the dates could not line up with the factor matrix.
no_china_sample <- no_china %>% filter(Date<="2019-12-31")
ts_no_china_factor_data <- t(no_china_sample %>% dplyr::select(-Date) %>% scale())

no_china_factors <- get.factor.model(ts_no_china_factor_data, bn=TRUE, normalisation=TRUE)
n_no_china_factors <- as.numeric(no_china_factors$q.hat)
# The factor labels below are hard-coded for exactly three factors.
if (length(n_no_china_factors) != 1 || n_no_china_factors != 3) {
  stop("Expected the information criterion to select exactly 3 factors", call. = FALSE)
}
no_china_factors_ts <- as.data.frame(no_china_factors$f) %>% head(n_no_china_factors)
no_china_factors_ts <- as.data.frame(t(no_china_factors_ts))
names(no_china_factors_ts) <- c("activity_factor", "prices_factor", "financial_factor")#, "interest_factor")
no_china_factors_ts$Date <- as.Date(no_china_sample$Date)

# Chinese GDP: dates are converted from Excel serial numbers and moved to the
# last day of their month. select() is namespace-qualified like every other
# call in this file.
china_gdp <- read_excel(paste0(data_input_folder, "china_gdp.xlsx")) %>%
  mutate(Date = lubridate::ceiling_date(as.Date(Date, origin="1899-12-30"), unit="month")-1) %>%
  dplyr::select(Date, china_gdp)

# Add the Chinese series back alongside the no-China factors, keeping only
# dates for which Chinese GDP is available
add_china <- df_global_qtr_stationary_nomiss %>% dplyr::select(Date, china_ip, China_10y, CN_PCPI_IX)
add_china <- left_join(china_gdp, add_china, by="Date")
w_china <- no_china_factors_ts %>% left_join(add_china, by="Date") %>%
  filter(!is.na(china_gdp))
